Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 179 additions & 1 deletion Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstring>
#include <future>
Expand Down Expand Up @@ -51,6 +52,12 @@ struct SequenceState {
llama_token pending_input_token = 0;
bool has_pending_input = false;

// Set by setForceWordContinuation; consumed (and cleared) when the next seed token is sampled.
bool force_word_continuation = false;

// Log-probability of the seed token, computed at decodePrompt and returned with the seed.
float seed_logprob = 0.0f;

~SequenceState() {
if (sampler) { llama_sampler_free(sampler); }
}
Expand All @@ -66,7 +73,9 @@ struct SequenceState {
seed_token(o.seed_token),
has_seed_token(o.has_seed_token),
pending_input_token(o.pending_input_token),
has_pending_input(o.has_pending_input) {
has_pending_input(o.has_pending_input),
force_word_continuation(o.force_word_continuation),
seed_logprob(o.seed_logprob) {
o.sampler = nullptr;
}
SequenceState& operator=(SequenceState&&) = delete;
Expand Down Expand Up @@ -119,6 +128,13 @@ struct CotabbyInferenceEngine::Impl {
int thread_count = 0;
int gpu_layer_count = 0;

// Token masks built once per model load (see buildTokenMasks). EOG tokens are deliberately
// excluded so the stop check still fires; they are never emitted as text. `starts_new_word`
// flags tokens whose decoded text begins with whitespace.
std::vector<llama_logit_bias> nonprintable_bias;
std::vector<llama_logit_bias> linebreak_bias;
std::vector<bool> starts_new_word;

// Public-facing sequence map (external int32_t IDs → state) and the
// internal `llama_seq_id` slot allocator.
mutable std::mutex sequences_mutex;
Expand Down Expand Up @@ -170,6 +186,23 @@ struct CotabbyInferenceEngine::Impl {
llama_sampler* chain = llama_sampler_chain_init(params);
if (!chain) return nullptr;

// Quality mask: control/unknown/unused tokens can never be sampled as visible text, and
// for single-line fields line-break tokens are masked too. Placed first so the -inf bias
// is absolute regardless of the temperature/top-k stages that follow. EOG is intentionally
// left sampleable so the stop check in processBatch/sampleNext still fires.
std::vector<llama_logit_bias> mask = nonprintable_bias;
if (cfg.single_line && !linebreak_bias.empty()) {
mask.insert(mask.end(), linebreak_bias.begin(), linebreak_bias.end());
}
if (!mask.empty()) {
auto* bias = llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
static_cast<int32_t>(mask.size()),
mask.data()
);
if (bias) llama_sampler_chain_add(chain, bias);
}

if (cfg.repetition_penalty > 1.0f) {
auto* pen = llama_sampler_init_penalties(
64, cfg.repetition_penalty, 0.0f, 0.0f
Expand Down Expand Up @@ -211,6 +244,88 @@ struct CotabbyInferenceEngine::Impl {
return chain;
}

// Classifies the whole vocabulary once per model load. Populates the logit-bias masks and the
// whitespace-leading flag used for first-token word continuation. Doing it here keeps the hot
// sampling path free of any per-token tokenizer calls.
void buildTokenMasks() {
nonprintable_bias.clear();
linebreak_bias.clear();
starts_new_word.clear();
if (!vocab) return;

const int32_t n = llama_vocab_n_tokens(vocab);
starts_new_word.assign(static_cast<size_t>(n), false);

char piece[64];
for (llama_token t = 0; t < n; ++t) {
const bool is_eog = llama_vocab_is_eog(vocab, t);

// Nonprintable: control (non-EOG), unknown, and unused tokens must never appear as
// text. EOG stays sampleable so the stop check can recognize a natural end of output.
if (!is_eog) {
const enum llama_token_attr attr = llama_vocab_get_attr(vocab, t);
const bool junk_attr =
(attr & (LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_UNUSED)) != 0;
if (llama_vocab_is_control(vocab, t) || junk_attr) {
nonprintable_bias.push_back({ t, -INFINITY });
}
}

const int written = llama_token_to_piece(vocab, t, piece, sizeof(piece), 0, false);
if (written <= 0) {
continue;
}
const char first = piece[0];
if (first == ' ' || first == '\t' || first == '\n' || first == '\r') {
starts_new_word[static_cast<size_t>(t)] = true;
}
if (!is_eog) {
for (int i = 0; i < written; ++i) {
if (piece[i] == '\n' || piece[i] == '\r') {
linebreak_bias.push_back({ t, -INFINITY });
break;
}
}
}
}
}

// Masks every "starts a new word" token (decoded text begins with whitespace) in the logits
// row so the next sampled token must continue the current word. Used for the first token only.
void maskNewWordStarts(int logits_row) {
if (!shared_ctx || !vocab) return;
float* logits = llama_get_logits_ith(shared_ctx, logits_row);
if (!logits) return;
const int32_t n = static_cast<int32_t>(starts_new_word.size());
for (llama_token t = 0; t < n; ++t) {
if (starts_new_word[static_cast<size_t>(t)]) {
logits[t] = -INFINITY;
}
}
}

// Log-probability of `token` under the raw model distribution at `logits_row`, used as a
// confidence signal. Two O(vocab) passes; only invoked on the autocomplete path.
float computeLogprob(int logits_row, llama_token token) const {
if (!shared_ctx || !vocab) return 0.0f;
const float* logits = llama_get_logits_ith(shared_ctx, logits_row);
if (!logits) return 0.0f;
const int32_t n = llama_vocab_n_tokens(vocab);
if (token < 0 || token >= n) return 0.0f;
float maxLogit = -INFINITY;
for (llama_token t = 0; t < n; ++t) {
if (logits[t] > maxLogit) { maxLogit = logits[t]; }
}
double sumExp = 0.0;
for (llama_token t = 0; t < n; ++t) {
sumExp += std::exp(static_cast<double>(logits[t] - maxLogit));
}
if (!(sumExp > 0.0)) return 0.0f;
return static_cast<float>(
static_cast<double>(logits[token] - maxLogit) - std::log(sumExp)
);
}

void destroyAllSequences() {
std::lock_guard<std::mutex> lock(sequences_mutex);
for (auto& [id, seq] : sequences) {
Expand Down Expand Up @@ -339,6 +454,7 @@ struct CotabbyInferenceEngine::Impl {
r.token = next;
r.piece = piece.c_str();
r.piece_length = static_cast<int>(piece.size());
r.logprob = computeLogprob(i, next);
}
}

Expand Down Expand Up @@ -440,6 +556,10 @@ EngineStatus CotabbyInferenceEngine::loadModel(const char* path, int gpu_layers,
return EngineStatus::error;
}

// Precompute the token masks now that the vocab is available; the sampler chains built in
// createSequence read these, and the hot path then needs no per-token tokenizer work.
impl_->buildTokenMasks();

impl_->startDecoderThread();
return EngineStatus::ok;
}
Expand Down Expand Up @@ -715,13 +835,22 @@ EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id,
llama_batch_free(batch);
seq->kv_position_count = total_end_position;

// First-token word-continuation constraint: when the caret is mid-word, mask new-word-start
// tokens for this seed only so the completion continues the current word instead of starting
// a new one. The flag clears after this single token.
if (seq->force_word_continuation) {
impl_->maskNewWordStarts(-1);
seq->force_word_continuation = false;
}

// Seed sample: take one token from the prompt's logits row right now,
// before any other sequence's decode can overwrite the shared logits
// buffer. The seed will be returned by the next sampleNext call as-is
// and feedback-decoded by the call after that.
llama_token seed = llama_sampler_sample(seq->sampler, impl_->shared_ctx, -1);
llama_sampler_accept(seq->sampler, seed);
seq->seed_token = seed;
seq->seed_logprob = impl_->computeLogprob(-1, seed);
seq->has_seed_token = true;
seq->has_pending_input = false;

Expand Down Expand Up @@ -792,6 +921,7 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) {
result.token = next;
result.piece = seq->last_piece.c_str();
result.piece_length = static_cast<int>(seq->last_piece.size());
result.logprob = seq->seed_logprob;
return result;
}

Expand Down Expand Up @@ -875,6 +1005,54 @@ int CotabbyInferenceEngine::getKVPositionCount(int32_t sequence_id) const {
return seq ? seq->kv_position_count : 0;
}

void CotabbyInferenceEngine::setForceWordContinuation(int32_t sequence_id, bool enabled) {
if (!impl_) return;
SequenceState* seq = impl_->findSequence(sequence_id);
if (seq) {
seq->force_word_continuation = enabled;
}
}

// ---------------------------------------------------------------------------
// KV state snapshot / restore
//
// Thin wrappers over llama's single-sequence state copy. They serialize with the decoder thread
// via decode_mutex so a snapshot or restore never races an in-flight llama_decode. Callers pair a
// snapshot with the KV position count they observed and pass it back on restore, so this engine's
// own position bookkeeping stays consistent with the restored KV cache.
// ---------------------------------------------------------------------------

size_t CotabbyInferenceEngine::snapshotSize(int32_t sequence_id) const {
if (!impl_ || !impl_->shared_ctx) return 0;
const SequenceState* seq = impl_->findSequence(sequence_id);
if (!seq) return 0;
return llama_state_seq_get_size(impl_->shared_ctx, seq->seq_id);
}

size_t CotabbyInferenceEngine::snapshotSequence(int32_t sequence_id, uint8_t* dst, size_t capacity) {
if (!impl_ || !impl_->shared_ctx || !dst) return 0;
SequenceState* seq = impl_->findSequence(sequence_id);
if (!seq) return 0;
std::lock_guard<std::mutex> lock(impl_->decode_mutex);
return llama_state_seq_get_data(impl_->shared_ctx, dst, capacity, seq->seq_id);
}

bool CotabbyInferenceEngine::restoreSequence(int32_t sequence_id, const uint8_t* src,
size_t size, int position_count) {
if (!impl_ || !impl_->shared_ctx || !src) return false;
SequenceState* seq = impl_->findSequence(sequence_id);
if (!seq) return false;
std::lock_guard<std::mutex> lock(impl_->decode_mutex);
const size_t read = llama_state_seq_set_data(impl_->shared_ctx, src, size, seq->seq_id);
if (read == 0) return false;
seq->kv_position_count = position_count;
// The restored blob invalidates any seed/pending token captured before; force the caller to
// re-prime via decodePrompt before the next sampleNext.
seq->has_seed_token = false;
seq->has_pending_input = false;
return true;
}

// ---------------------------------------------------------------------------
// Cancellation
// ---------------------------------------------------------------------------
Expand Down
22 changes: 22 additions & 0 deletions Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>
#include <swift/bridging>
Expand All @@ -11,6 +12,10 @@ struct SamplingConfig {
float min_p;
float repetition_penalty;
uint32_t seed;

// When true, tokens that introduce a line break are masked from sampling so single-line
// fields never receive a multi-line completion. Defaults to false to preserve prior behavior.
bool single_line = false;
};

struct SWIFT_SELF_CONTAINED SampleResult {
Expand All @@ -19,6 +24,9 @@ struct SWIFT_SELF_CONTAINED SampleResult {
int piece_length;
bool is_eos;
bool was_cancelled;
// Log-probability of the chosen token under the raw model distribution (<= 0). Used as a
// confidence signal; 0 for the EOS/cancelled cases where it carries no meaning.
float logprob;
};

enum class EngineStatus : int {
Expand Down Expand Up @@ -101,6 +109,20 @@ class CotabbyInferenceEngine {
bool trimKV(int32_t sequence_id, int keep_positions);
int getKVPositionCount(int32_t sequence_id) const;

// Constrains the FIRST token of the next generation on `sequence_id` to continue the current
// word: tokens whose decoded text begins with whitespace are masked for that one token, then
// the constraint clears. Set this before `decodePrompt`, which samples the first (seed) token.
void setForceWordContinuation(int32_t sequence_id, bool enabled);

// Single-sequence KV state snapshot/restore. `snapshotSize` reports the buffer size needed,
// `snapshotSequence` copies the sequence's KV state into `dst` (returns bytes written, 0 on
// failure), and `restoreSequence` loads a previously captured blob back and sets the KV
// position to `position_count` (returns false on failure). Blobs are model- and context-
// specific; never restore one across a model reload.
size_t snapshotSize(int32_t sequence_id) const;
size_t snapshotSequence(int32_t sequence_id, uint8_t* dst, size_t capacity);
bool restoreSequence(int32_t sequence_id, const uint8_t* src, size_t size, int position_count);

// Cancellation (thread-safe, non-blocking)
void cancelSequence(int32_t sequence_id);

Expand Down
Loading
Loading