// Review notes for this patch (free text placed ahead of the first `diff --git` header).
//
// Purpose: add chat-template support and an option-taking tokenizer entry point to
// CotabbyInferenceEngine, plus tests for the new entry points.
//
// Files touched:
//   - Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
//   - Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
//   - Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift
//
// CotabbyInferenceEngine.cpp
//   - tokenize() becomes a wrapper: it reads add_bos through llama_vocab_get_add_bos()
//     (false when impl_->vocab is null) and returns
//     tokenizeWithOptions(text, text_length, add_bos, false).
//   - tokenizeWithOptions() takes over the former tokenize() body. It returns {} when
//     impl_->vocab is null, text is null, or text_length <= 0, and passes
//     add_special / parse_special where the old code passed add_bos / false.
//   - hasChatTemplate() returns false when impl_->model is null; otherwise it reports
//     whether llama_model_chat_template(impl_->model, nullptr) is non-null.
//   - applyChatTemplate() returns an empty string when impl_->model is null, messages is
//     null, message_count <= 0, the model has no template, or
//     llama_chat_apply_template() returns a negative value. Otherwise it builds a
//     llama_chat_message array that points at each ChatMessage's role/content c_str(),
//     starts with a buffer of max(total_chars * 2, 256), and when the returned length
//     exceeds the buffer it resizes the buffer to that length and calls again.
//
// CotabbyInferenceEngine.h
//   - Adds struct ChatMessage { std::string role; std::string content; } and one extra
//     #include line.
//   - Declares tokenizeWithOptions(), hasChatTemplate() and applyChatTemplate().
//
// LlamaMiddlewareTests.swift
//   - Three new tests construct CotabbyInferenceEngine() without loading a model and
//     assert that tokenizeWithOptions() is empty, hasChatTemplate() is false, and
//     applyChatTemplate() is empty.
//   - An existing test (its name is outside the hunk) gains a section guarded by
//     `if engine.hasChatTemplate()` that renders two messages and tokenizes the result
//     with add_special = false, parse_special = true.
//
// NOTE(review): that guarded section asserts the rendered prompt is non-empty whenever
//   hasChatTemplate() is true, but applyChatTemplate() returns empty when
//   llama_chat_apply_template() reports failure (n < 0). A model whose template is
//   rejected there would fail the assertion even though the header documents empty as
//   "fall back to the raw path" -- confirm this is intended.
// NOTE(review): in this copy of the patch the hunk line breaks are collapsed and text
//   inside angle brackets (e.g. after std::vector, static_cast, std::max, #include)
//   appears to be missing -- presumably an extraction artifact; verify against the
//   original commit before applying.
diff --git a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp index d1d77c2..026db00 100644 --- a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp +++ b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp @@ -539,11 +539,19 @@ void CotabbyInferenceEngine::destroySequence(int32_t sequence_id) { std::vector CotabbyInferenceEngine::tokenize(const char* text, int text_length) const { + // Preserve the historical contract: add BOS per model metadata, treat any + // special-token text as plaintext (parse_special = false). + bool add_bos = impl_->vocab ? llama_vocab_get_add_bos(impl_->vocab) : false; + return tokenizeWithOptions(text, text_length, add_bos, false); +} + +std::vector CotabbyInferenceEngine::tokenizeWithOptions( + const char* text, int text_length, + bool add_special, bool parse_special) const { if (!impl_->vocab || !text || text_length <= 0) { return {}; } - bool add_bos = llama_vocab_get_add_bos(impl_->vocab); int capacity = text_length + 8; while (true) { @@ -554,8 +562,8 @@ std::vector CotabbyInferenceEngine::tokenize(const char* text, static_cast(text_length), tokens.data(), static_cast(capacity), - add_bos, - false + add_special, + parse_special ); if (n > 0) { @@ -569,6 +577,64 @@ std::vector CotabbyInferenceEngine::tokenize(const char* text, } } +bool CotabbyInferenceEngine::hasChatTemplate() const { + if (!impl_->model) { + return false; + } + return llama_model_chat_template(impl_->model, /*name=*/nullptr) != nullptr; +} + +std::string CotabbyInferenceEngine::applyChatTemplate( + const ChatMessage* messages, int message_count, + bool add_assistant) const { + if (!impl_->model || !messages || message_count <= 0) { + return {}; + } + + const char* tmpl = llama_model_chat_template(impl_->model, /*name=*/nullptr); + if (!tmpl) { + return {}; + } + + // `llama_chat_message` holds borrowed `const char*`. 
The backing + // std::strings live in `messages` for the duration of this call, so + // pointing at their c_str() is safe. + std::vector chat; + chat.reserve(message_count); + size_t total_chars = 0; + for (int i = 0; i < message_count; ++i) { + chat.push_back(llama_chat_message{ + messages[i].role.c_str(), + messages[i].content.c_str() + }); + total_chars += messages[i].role.size() + messages[i].content.size(); + } + + // The header recommends an initial buffer of 2x the total message + // characters; grow and retry if the template expands beyond that. + std::vector buf(std::max(total_chars * 2, 256)); + while (true) { + int32_t n = llama_chat_apply_template( + tmpl, + chat.data(), + chat.size(), + add_assistant, + buf.data(), + static_cast(buf.size()) + ); + + if (n < 0) { + // Template not supported by llama.cpp's predefined list, or some + // other failure. Signal "fall back to the raw path". + return {}; + } + if (static_cast(n) <= buf.size()) { + return std::string(buf.data(), static_cast(n)); + } + buf.resize(static_cast(n)); + } +} + int CotabbyInferenceEngine::detokenize(int32_t token, char* buffer, int buffer_size) const { if (!impl_->vocab || !buffer || buffer_size <= 0) return 0; diff --git a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h index dca6f12..9ca3ae4 100644 --- a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h +++ b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include @@ -13,6 +14,13 @@ struct SamplingConfig { uint32_t seed; }; +/// One message in a chat-template conversation, mirroring `llama_chat_message`. +/// Roles are the usual "system" / "user" / "assistant". Owned by the caller. 
+struct ChatMessage { + std::string role; + std::string content; +}; + struct SWIFT_SELF_CONTAINED SampleResult { int32_t token; const char* piece; @@ -51,8 +59,33 @@ class CotabbyInferenceEngine { // Tokenization (thread-safe, read-only on vocab) std::vector tokenize(const char* text, int text_length) const; + // Like `tokenize`, but the caller controls BOS/EOS injection and whether + // special/control tokens in the text (e.g. chat-template markers like + // <|im_start|>) are recognized as their token IDs instead of plain text. + // The plain `tokenize` keeps `parse_special = false` for backward + // compatibility; the chat-template path needs `true` so rendered markers + // tokenize correctly. + std::vector tokenizeWithOptions(const char* text, int text_length, + bool add_special, + bool parse_special) const; int detokenize(int32_t token, char* buffer, int buffer_size) const; + // Chat templates + // + // `hasChatTemplate` reports whether the loaded model ships a chat template + // in its GGUF metadata. Instruct models (Qwen, Gemma, Llama) do; raw base + // models do not. Callers use this to decide between the structured + // chat-template prompt path and the legacy raw-continuation path so a + // user-supplied base model keeps working. + bool hasChatTemplate() const; + // Renders `messages` through the model's built-in chat template and returns + // the formatted prompt string. `add_assistant` appends the assistant-turn + // opening marker so the model continues as the assistant. Returns an empty + // string if no model is loaded, the model has no template, or formatting + // fails — callers must treat empty as "fall back to the raw path". 
+ std::string applyChatTemplate(const ChatMessage* messages, int message_count, + bool add_assistant) const; + // Prompt decoding EngineStatus decodePrompt(int32_t sequence_id, const int32_t* tokens, int token_count, diff --git a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift index 0826289..5f0d0b1 100644 --- a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift +++ b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift @@ -55,6 +55,30 @@ final class LlamaMiddlewareTests: XCTestCase { XCTAssertTrue(tokens.isEmpty) } + func testTokenizeWithOptionsWithoutModelReturnsEmpty() { + let engine = CotabbyInferenceEngine() + let text = "hello" + let tokens = engine.tokenizeWithOptions( + text, Int32(text.utf8.count), false, true + ) + XCTAssertTrue(tokens.isEmpty) + } + + func testHasChatTemplateWithoutModelIsFalse() { + let engine = CotabbyInferenceEngine() + XCTAssertFalse(engine.hasChatTemplate()) + } + + func testApplyChatTemplateWithoutModelReturnsEmpty() { + let engine = CotabbyInferenceEngine() + var messages = [ChatMessage]() + messages.append(ChatMessage(role: "user", content: "hi")) + let rendered = messages.withUnsafeBufferPointer { buf in + engine.applyChatTemplate(buf.baseAddress, Int32(buf.count), true) + } + XCTAssertTrue(rendered.isEmpty) + } + func testDiagnosticsDefaultToZero() { let engine = CotabbyInferenceEngine() XCTAssertEqual(engine.getContextWindowTokens(), 0) @@ -94,6 +118,26 @@ final class LlamaMiddlewareTests: XCTestCase { let tokens = engine.tokenize(prompt, Int32(prompt.utf8.count)) XCTAssertFalse(tokens.isEmpty) + // Chat-template path: instruct models ship a template; if present, + // rendering a simple conversation must produce a non-empty prompt that + // tokenizes (with parse_special) to a non-empty token list. 
+ if engine.hasChatTemplate() { + var messages = [ChatMessage]() + messages.append(ChatMessage(role: "system", content: "You complete text.")) + messages.append(ChatMessage(role: "user", content: "The quick brown")) + let rendered = messages.withUnsafeBufferPointer { buf in + engine.applyChatTemplate(buf.baseAddress, Int32(buf.count), true) + } + // applyChatTemplate returns a C++ std::string; bridge to a Swift + // String before using String APIs like .utf8. + let renderedSwift = String(rendered) + XCTAssertFalse(renderedSwift.isEmpty, "Model reports a template but rendering was empty") + let templated = engine.tokenizeWithOptions( + renderedSwift, Int32(renderedSwift.utf8.count), false, true + ) + XCTAssertFalse(templated.isEmpty) + } + // Detokenize first token var buf = [CChar](repeating: 0, count: 64) let written = engine.detokenize(tokens[0], &buf, Int32(buf.count))