Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 34 additions & 40 deletions Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,55 +584,49 @@ bool CotabbyInferenceEngine::hasChatTemplate() const {
return llama_model_chat_template(impl_->model, /*name=*/nullptr) != nullptr;
}

std::string CotabbyInferenceEngine::applyChatTemplate(
const ChatMessage* messages, int message_count,
bool add_assistant) const {
if (!impl_->model || !messages || message_count <= 0) {
return {};
int CotabbyInferenceEngine::applyChatTemplate(
const char* system_text,
const char* user_text,
bool add_assistant,
char* buffer,
int buffer_size) const {
if (!impl_->model || !system_text || !user_text ||
!buffer || buffer_size <= 0) {
return 0;
}

const char* tmpl = llama_model_chat_template(impl_->model, /*name=*/nullptr);
if (!tmpl) {
return {};
return 0;
}

// `llama_chat_message` holds borrowed `const char*`. The backing
// std::strings live in `messages` for the duration of this call, so
// pointing at their c_str() is safe.
std::vector<llama_chat_message> chat;
chat.reserve(message_count);
size_t total_chars = 0;
for (int i = 0; i < message_count; ++i) {
chat.push_back(llama_chat_message{
messages[i].role.c_str(),
messages[i].content.c_str()
});
total_chars += messages[i].role.size() + messages[i].content.size();
}
// Borrowed `const char*` from the caller; valid for this call's duration.
llama_chat_message chat[2] = {
{ "system", system_text },
{ "user", user_text }
};

// The header recommends an initial buffer of 2x the total message
// characters; grow and retry if the template expands beyond that.
std::vector<char> buf(std::max<size_t>(total_chars * 2, 256));
while (true) {
int32_t n = llama_chat_apply_template(
tmpl,
chat.data(),
chat.size(),
add_assistant,
buf.data(),
static_cast<int32_t>(buf.size())
);
int32_t n = llama_chat_apply_template(
tmpl,
chat,
2,
add_assistant,
buffer,
static_cast<int32_t>(buffer_size)
);

if (n < 0) {
// Template not supported by llama.cpp's predefined list, or some
// other failure. Signal "fall back to the raw path".
return {};
}
if (static_cast<size_t>(n) <= buf.size()) {
return std::string(buf.data(), static_cast<size_t>(n));
}
buf.resize(static_cast<size_t>(n));
// Contract of llama_chat_apply_template: returns the total byte length of
// the formatted prompt; negative means the template is unsupported by
// llama.cpp's predefined list. A positive value larger than the buffer
// means the output did not fit and the caller must retry with a bigger
// buffer. Map all three onto this function's documented C-ABI contract.
if (n < 0) {
return 0; // genuine render failure → caller falls back to raw
}
if (n > buffer_size) {
return -n; // too small → -(required size); caller resizes and retries
}
return n; // success: n bytes written (n <= buffer_size)
}

int CotabbyInferenceEngine::detokenize(int32_t token, char* buffer,
Expand Down
33 changes: 18 additions & 15 deletions Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once
#include <cstdint>
#include <string>
#include <vector>
#include <swift/bridging>

Expand All @@ -14,13 +13,6 @@ struct SamplingConfig {
uint32_t seed;
};

/// One message in a chat-template conversation, mirroring `llama_chat_message`.
/// Roles are the usual "system" / "user" / "assistant". Owned by the caller.
struct ChatMessage {
std::string role;
std::string content;
};

struct SWIFT_SELF_CONTAINED SampleResult {
int32_t token;
const char* piece;
Expand Down Expand Up @@ -78,13 +70,24 @@ class CotabbyInferenceEngine {
// chat-template prompt path and the legacy raw-continuation path so a
// user-supplied base model keeps working.
bool hasChatTemplate() const;
// Renders `messages` through the model's built-in chat template and returns
// the formatted prompt string. `add_assistant` appends the assistant-turn
// opening marker so the model continues as the assistant. Returns an empty
// string if no model is loaded, the model has no template, or formatting
// fails — callers must treat empty as "fall back to the raw path".
std::string applyChatTemplate(const ChatMessage* messages, int message_count,
bool add_assistant) const;
// Renders a system + user turn through the model's built-in chat template
// into `buffer`. `add_assistant` appends the assistant-turn opening marker
// so the model continues as the assistant. Returns:
// > 0 : number of bytes written (<= buffer_size) — the formatted prompt.
// < 0 : -(required buffer size); the buffer was too small, retry at that size.
// = 0 : no model, no template, or render failure — caller falls back to raw.
//
// Autocomplete needs exactly one system turn (rules + context) and one user
// turn (the text to continue), so the signature takes those two directly
// rather than a message array. This buffer-based C ABI mirrors `detokenize`
// and deliberately avoids std::string / struct parameter and return types,
// so it bridges cleanly into the Swift objcxx interop mode the app target
// uses (where a std::string return does not bridge).
int applyChatTemplate(const char* system_text,
const char* user_text,
bool add_assistant,
char* buffer,
int buffer_size) const;

// Prompt decoding
EngineStatus decodePrompt(int32_t sequence_id,
Expand Down
38 changes: 22 additions & 16 deletions Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ final class LlamaMiddlewareTests: XCTestCase {
XCTAssertFalse(engine.hasChatTemplate())
}

func testApplyChatTemplateWithoutModelReturnsEmpty() {
func testApplyChatTemplateWithoutModelReturnsZero() {
let engine = CotabbyInferenceEngine()
var messages = [ChatMessage]()
messages.append(ChatMessage(role: "user", content: "hi"))
let rendered = messages.withUnsafeBufferPointer { buf in
engine.applyChatTemplate(buf.baseAddress, Int32(buf.count), true)
}
XCTAssertTrue(rendered.isEmpty)
var buffer = [CChar](repeating: 0, count: 256)
let written = engine.applyChatTemplate(
"You complete text.", "The quick brown", true, &buffer, Int32(buffer.count)
)
// No model loaded → 0 (caller falls back to the raw path).
XCTAssertEqual(written, 0)
}

func testDiagnosticsDefaultToZero() {
Expand Down Expand Up @@ -122,16 +122,22 @@ final class LlamaMiddlewareTests: XCTestCase {
// rendering a simple conversation must produce a non-empty prompt that
// tokenizes (with parse_special) to a non-empty token list.
if engine.hasChatTemplate() {
var messages = [ChatMessage]()
messages.append(ChatMessage(role: "system", content: "You complete text."))
messages.append(ChatMessage(role: "user", content: "The quick brown"))
let rendered = messages.withUnsafeBufferPointer { buf in
engine.applyChatTemplate(buf.baseAddress, Int32(buf.count), true)
// Render system + user through the model's template into a caller buffer.
var buffer = [CChar](repeating: 0, count: 4096)
let written = engine.applyChatTemplate(
"You complete text.", "The quick brown", true, &buffer, Int32(buffer.count)
)
XCTAssertGreaterThan(written, 0, "Model reports a template but rendering produced no bytes")

let rendered = buffer.prefix(Int(written)).withUnsafeBufferPointer { ptr in
String(
bytes: UnsafeRawBufferPointer(ptr),
encoding: .utf8
)
}
// applyChatTemplate returns a C++ std::string; bridge to a Swift
// String before using String APIs like .utf8.
let renderedSwift = String(rendered)
XCTAssertFalse(renderedSwift.isEmpty, "Model reports a template but rendering was empty")
let renderedSwift = try XCTUnwrap(rendered, "Rendered template was not valid UTF-8")
XCTAssertFalse(renderedSwift.isEmpty)

let templated = engine.tokenizeWithOptions(
renderedSwift, Int32(renderedSwift.utf8.count), false, true
)
Expand Down
Loading