Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,19 @@ void CotabbyInferenceEngine::destroySequence(int32_t sequence_id) {

// Tokenizes `text` using the engine's long-standing defaults and forwards to
// tokenizeWithOptions:
//   * BOS insertion follows the vocab's own add-BOS setting (false when no
//     vocab is available),
//   * special-token text is never parsed; it is tokenized as plain text.
std::vector<int32_t> CotabbyInferenceEngine::tokenize(const char* text,
                                                      int text_length) const {
    bool wants_bos = false;
    if (impl_->vocab) {
        wants_bos = llama_vocab_get_add_bos(impl_->vocab);
    }

    const bool treat_special_as_tokens = false;
    return tokenizeWithOptions(text, text_length, wants_bos,
                               treat_special_as_tokens);
}

std::vector<int32_t> CotabbyInferenceEngine::tokenizeWithOptions(
const char* text, int text_length,
bool add_special, bool parse_special) const {
if (!impl_->vocab || !text || text_length <= 0) {
return {};
}

bool add_bos = llama_vocab_get_add_bos(impl_->vocab);
int capacity = text_length + 8;

while (true) {
Expand All @@ -554,8 +562,8 @@ std::vector<int32_t> CotabbyInferenceEngine::tokenize(const char* text,
static_cast<int32_t>(text_length),
tokens.data(),
static_cast<int32_t>(capacity),
add_bos,
false
add_special,
parse_special
);

if (n > 0) {
Expand All @@ -569,6 +577,64 @@ std::vector<int32_t> CotabbyInferenceEngine::tokenize(const char* text,
}
}

// Reports whether a chat template can be obtained from the loaded model.
// Returns false when no model is loaded or when the default-template lookup
// (name == nullptr) yields a null pointer.
bool CotabbyInferenceEngine::hasChatTemplate() const {
    const char* default_template = nullptr;
    if (impl_->model) {
        default_template =
            llama_model_chat_template(impl_->model, /*name=*/nullptr);
    }
    return default_template != nullptr;
}

// Renders `messages` through the loaded model's default chat template and
// returns the formatted prompt.
//
// Parameters:
//   messages       - array of `message_count` caller-owned messages.
//   message_count  - number of entries in `messages`; must be > 0.
//   add_assistant  - forwarded to llama_chat_apply_template.
//
// Returns an empty string when no model is loaded, `messages` is null or
// empty, the model exposes no template, the template call reports an error,
// or the rendered prompt cannot be described by the API's int32_t buffer
// length. Callers treat empty as "fall back to the raw path".
std::string CotabbyInferenceEngine::applyChatTemplate(
    const ChatMessage* messages, int message_count,
    bool add_assistant) const {
    if (!impl_->model || !messages || message_count <= 0) {
        return {};
    }

    const char* tmpl = llama_model_chat_template(impl_->model, /*name=*/nullptr);
    if (!tmpl) {
        return {};
    }

    // `llama_chat_message` holds borrowed `const char*`. The backing
    // std::strings live in `messages` for the duration of this call, so
    // pointing at their c_str() is safe.
    std::vector<llama_chat_message> chat;
    chat.reserve(static_cast<size_t>(message_count));
    size_t total_chars = 0;
    for (int i = 0; i < message_count; ++i) {
        chat.push_back(llama_chat_message{
            messages[i].role.c_str(),
            messages[i].content.c_str()
        });
        total_chars += messages[i].role.size() + messages[i].content.size();
    }

    // The buffer length is handed to the API as an int32_t, so the buffer
    // must never be larger than INT32_MAX; otherwise the cast below would
    // produce a bogus (possibly negative) length.
    const size_t max_capacity = static_cast<size_t>(INT32_MAX);

    // Start at 2x the total message characters (minimum 256) and grow and
    // retry if the template expands beyond that.
    const size_t initial_capacity = std::max<size_t>(total_chars * 2, 256);
    std::vector<char> buf(std::min(initial_capacity, max_capacity));

    while (true) {
        int32_t n = llama_chat_apply_template(
            tmpl,
            chat.data(),
            chat.size(),
            add_assistant,
            buf.data(),
            static_cast<int32_t>(buf.size())
        );

        if (n < 0) {
            // Template not supported by llama.cpp's predefined list, or some
            // other failure. Signal "fall back to the raw path".
            return {};
        }

        const size_t needed = static_cast<size_t>(n);
        // Accept the result only when the buffer has at least one spare byte.
        // NOTE(review): whether llama_chat_apply_template NUL-terminates its
        // output is not visible from here; insisting on spare room keeps the
        // last character intact either way, at the cost of one extra call
        // when the output is an exact fit.
        if (needed < buf.size()) {
            return std::string(buf.data(), needed);
        }

        if (needed >= max_capacity) {
            // needed + 1 bytes cannot be expressed as an int32_t length.
            return {};
        }
        buf.resize(needed + 1);
    }
}

int CotabbyInferenceEngine::detokenize(int32_t token, char* buffer,
int buffer_size) const {
if (!impl_->vocab || !buffer || buffer_size <= 0) return 0;
Expand Down
33 changes: 33 additions & 0 deletions Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <cstdint>
#include <string>
#include <vector>
#include <swift/bridging>

Expand All @@ -13,6 +14,13 @@ struct SamplingConfig {
uint32_t seed;
};

/// One message in a chat-template conversation, mirroring `llama_chat_message`.
/// Roles are the usual "system" / "user" / "assistant". Owned by the caller.
///
/// `applyChatTemplate` reads both fields through `c_str()` while building its
/// `llama_chat_message` array, so the strings must stay alive and unmodified
/// for the duration of that call.
struct ChatMessage {
/// Speaker of this message, e.g. "system", "user" or "assistant".
std::string role;
/// Text of the message, passed to the chat template as-is.
std::string content;
};

struct SWIFT_SELF_CONTAINED SampleResult {
int32_t token;
const char* piece;
Expand Down Expand Up @@ -51,8 +59,33 @@ class CotabbyInferenceEngine {

// Tokenization (thread-safe, read-only on vocab)
// Tokenizes `text` with the historical defaults: BOS is added according to
// the model's vocab metadata, and any special-token text is treated as plain
// text (parse_special = false). Returns an empty vector when no vocab is
// loaded, `text` is null, or `text_length` is not positive.
std::vector<int32_t> tokenize(const char* text, int text_length) const;
// Like `tokenize`, but the caller controls BOS/EOS injection and whether
// special/control tokens in the text (e.g. chat-template markers like
// <|im_start|>) are recognized as their token IDs instead of plain text.
// The plain `tokenize` keeps `parse_special = false` for backward
// compatibility; the chat-template path needs `true` so rendered markers
// tokenize correctly.
std::vector<int32_t> tokenizeWithOptions(const char* text, int text_length,
bool add_special,
bool parse_special) const;
int detokenize(int32_t token, char* buffer, int buffer_size) const;

// Chat templates
//
// `hasChatTemplate` reports whether the loaded model ships a chat template
// in its GGUF metadata. Instruct models (Qwen, Gemma, Llama) do; raw base
// models do not. Callers use this to decide between the structured
// chat-template prompt path and the legacy raw-continuation path so a
// user-supplied base model keeps working.
bool hasChatTemplate() const;
// Renders `messages` through the model's built-in chat template and returns
// the formatted prompt string. `add_assistant` appends the assistant-turn
// opening marker so the model continues as the assistant. Returns an empty
// string if no model is loaded, the model has no template, or formatting
// fails — callers must treat empty as "fall back to the raw path".
std::string applyChatTemplate(const ChatMessage* messages, int message_count,
bool add_assistant) const;

// Prompt decoding
EngineStatus decodePrompt(int32_t sequence_id,
const int32_t* tokens, int token_count,
Expand Down
44 changes: 44 additions & 0 deletions Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ final class LlamaMiddlewareTests: XCTestCase {
XCTAssertTrue(tokens.isEmpty)
}

/// With no model loaded, option-controlled tokenization must yield no tokens,
/// even when special-token parsing is requested.
func testTokenizeWithOptionsWithoutModelReturnsEmpty() {
    let sample = "hello"
    let byteCount = Int32(sample.utf8.count)

    let unloadedEngine = CotabbyInferenceEngine()
    let result = unloadedEngine.tokenizeWithOptions(sample, byteCount, false, true)

    XCTAssertTrue(result.isEmpty)
}

/// An engine that never loaded a model must not report a chat template.
func testHasChatTemplateWithoutModelIsFalse() {
    let unloadedEngine = CotabbyInferenceEngine()
    let reportsTemplate = unloadedEngine.hasChatTemplate()
    XCTAssertFalse(reportsTemplate)
}

/// Rendering a conversation without a loaded model must return an empty
/// string, which callers interpret as "use the raw prompt path".
func testApplyChatTemplateWithoutModelReturnsEmpty() {
    let conversation: [ChatMessage] = [ChatMessage(role: "user", content: "hi")]
    let unloadedEngine = CotabbyInferenceEngine()

    let output = conversation.withUnsafeBufferPointer { pointer in
        unloadedEngine.applyChatTemplate(pointer.baseAddress, Int32(pointer.count), true)
    }

    XCTAssertTrue(output.isEmpty)
}

func testDiagnosticsDefaultToZero() {
let engine = CotabbyInferenceEngine()
XCTAssertEqual(engine.getContextWindowTokens(), 0)
Expand Down Expand Up @@ -94,6 +118,26 @@ final class LlamaMiddlewareTests: XCTestCase {
let tokens = engine.tokenize(prompt, Int32(prompt.utf8.count))
XCTAssertFalse(tokens.isEmpty)

// Chat-template path: instruct models ship a template; if present,
// rendering a simple conversation must produce a non-empty prompt that
// tokenizes (with parse_special) to a non-empty token list.
if engine.hasChatTemplate() {
var messages = [ChatMessage]()
messages.append(ChatMessage(role: "system", content: "You complete text."))
messages.append(ChatMessage(role: "user", content: "The quick brown"))
let rendered = messages.withUnsafeBufferPointer { buf in
engine.applyChatTemplate(buf.baseAddress, Int32(buf.count), true)
}
// applyChatTemplate returns a C++ std::string; bridge to a Swift
// String before using String APIs like .utf8.
let renderedSwift = String(rendered)
XCTAssertFalse(renderedSwift.isEmpty, "Model reports a template but rendering was empty")
let templated = engine.tokenizeWithOptions(
renderedSwift, Int32(renderedSwift.utf8.count), false, true
)
XCTAssertFalse(templated.isEmpty)
}

// Detokenize first token
var buf = [CChar](repeating: 0, count: 64)
let written = engine.detokenize(tokens[0], &buf, Int32(buf.count))
Expand Down
Loading