From c2d3cd191230dfab4cf49c7974b927a086902df0 Mon Sep 17 00:00:00 2001 From: Jacob Fu <141651335+FuJacob@users.noreply.github.com> Date: Sun, 31 May 2026 11:56:37 -0700 Subject: [PATCH 1/2] Add generation-time quality controls: token masks, mid-word continuation, KV snapshot - buildTokenMasks classifies the vocab once per model load: control, unknown, and unused tokens get a -inf logit bias so they can never be sampled as visible text. EOG is deliberately left sampleable so the existing stop check still fires. - single_line SamplingConfig flag additionally masks line-break tokens so single-line fields never receive a multi-line completion. - setForceWordContinuation constrains the first (seed) token of a generation to continue the current word (masks whitespace-leading tokens) for mid-word carets. - snapshotSize / snapshotSequence / restoreSequence wrap llama single-sequence KV state copy, serialized with the decoder thread. Tests: new no-model and model-gated tests for the masks, mid-word continuation, and snapshot/restore round-trip. Also fixes a latent end-to-end test bug where detokenizing tokens[0] (BOS, a control token) returned 0 bytes. --- .../CotabbyInferenceEngine.cpp | 151 +++++++++++++++++- .../include/CotabbyInferenceEngine.h | 19 +++ .../LlamaMiddlewareTests.swift | 121 +++++++++++++- 3 files changed, 282 insertions(+), 9 deletions(-) diff --git a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp index fc64bea..abdcf1c 100644 --- a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp +++ b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -51,6 +52,9 @@ struct SequenceState { llama_token pending_input_token = 0; bool has_pending_input = false; + // Set by setForceWordContinuation; consumed (and cleared) when the next seed token is sampled. + bool force_word_continuation = false; + ~SequenceState() { if (sampler) { llama_sampler_free(sampler); } } @@ -66,7 +70,8 @@ struct SequenceState { seed_token(o.seed_token), has_seed_token(o.has_seed_token), pending_input_token(o.pending_input_token), - has_pending_input(o.has_pending_input) { + has_pending_input(o.has_pending_input), + force_word_continuation(o.force_word_continuation) { o.sampler = nullptr; } SequenceState& operator=(SequenceState&&) = delete; @@ -119,6 +124,13 @@ struct CotabbyInferenceEngine::Impl { int thread_count = 0; int gpu_layer_count = 0; + // Token masks built once per model load (see buildTokenMasks). EOG tokens are deliberately + // excluded so the stop check still fires; they are never emitted as text. `starts_new_word` + // flags tokens whose decoded text begins with whitespace. + std::vector nonprintable_bias; + std::vector linebreak_bias; + std::vector starts_new_word; + // Public-facing sequence map (external int32_t IDs → state) and the // internal `llama_seq_id` slot allocator. mutable std::mutex sequences_mutex; @@ -170,6 +182,23 @@ struct CotabbyInferenceEngine::Impl { llama_sampler* chain = llama_sampler_chain_init(params); if (!chain) return nullptr; + // Quality mask: control/unknown/unused tokens can never be sampled as visible text, and + // for single-line fields line-break tokens are masked too. Placed first so the -inf bias + // is absolute regardless of the temperature/top-k stages that follow. EOG is intentionally + // left sampleable so the stop check in processBatch/sampleNext still fires. + std::vector mask = nonprintable_bias; + if (cfg.single_line && !linebreak_bias.empty()) { + mask.insert(mask.end(), linebreak_bias.begin(), linebreak_bias.end()); + } + if (!mask.empty()) { + auto* bias = llama_sampler_init_logit_bias( + llama_vocab_n_tokens(vocab), + static_cast(mask.size()), + mask.data() + ); + if (bias) llama_sampler_chain_add(chain, bias); + } + if (cfg.repetition_penalty > 1.0f) { auto* pen = llama_sampler_init_penalties( 64, cfg.repetition_penalty, 0.0f, 0.0f @@ -211,6 +240,66 @@ struct CotabbyInferenceEngine::Impl { return chain; } + // Classifies the whole vocabulary once per model load. Populates the logit-bias masks and the + // whitespace-leading flag used for first-token word continuation. Doing it here keeps the hot + // sampling path free of any per-token tokenizer calls. + void buildTokenMasks() { + nonprintable_bias.clear(); + linebreak_bias.clear(); + starts_new_word.clear(); + if (!vocab) return; + + const int32_t n = llama_vocab_n_tokens(vocab); + starts_new_word.assign(static_cast(n), false); + + char piece[64]; + for (llama_token t = 0; t < n; ++t) { + const bool is_eog = llama_vocab_is_eog(vocab, t); + + // Nonprintable: control (non-EOG), unknown, and unused tokens must never appear as + // text. EOG stays sampleable so the stop check can recognize a natural end of output. + if (!is_eog) { + const enum llama_token_attr attr = llama_vocab_get_attr(vocab, t); + const bool junk_attr = + (attr & (LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_UNUSED)) != 0; + if (llama_vocab_is_control(vocab, t) || junk_attr) { + nonprintable_bias.push_back({ t, -INFINITY }); + } + } + + const int written = llama_token_to_piece(vocab, t, piece, sizeof(piece), 0, false); + if (written <= 0) { + continue; + } + const char first = piece[0]; + if (first == ' ' || first == '\t' || first == '\n' || first == '\r') { + starts_new_word[static_cast(t)] = true; + } + if (!is_eog) { + for (int i = 0; i < written; ++i) { + if (piece[i] == '\n' || piece[i] == '\r') { + linebreak_bias.push_back({ t, -INFINITY }); + break; + } + } + } + } + } + + // Masks every "starts a new word" token (decoded text begins with whitespace) in the logits + // row so the next sampled token must continue the current word. Used for the first token only. + void maskNewWordStarts(int logits_row) { + if (!shared_ctx || !vocab) return; + float* logits = llama_get_logits_ith(shared_ctx, logits_row); + if (!logits) return; + const int32_t n = static_cast(starts_new_word.size()); + for (llama_token t = 0; t < n; ++t) { + if (starts_new_word[static_cast(t)]) { + logits[t] = -INFINITY; + } + } + } + void destroyAllSequences() { std::lock_guard lock(sequences_mutex); for (auto& [id, seq] : sequences) { @@ -440,6 +529,10 @@ EngineStatus CotabbyInferenceEngine::loadModel(const char* path, int gpu_layers, return EngineStatus::error; } + // Precompute the token masks now that the vocab is available; the sampler chains built in + // createSequence read these, and the hot path then needs no per-token tokenizer work. + impl_->buildTokenMasks(); + impl_->startDecoderThread(); return EngineStatus::ok; } @@ -715,6 +808,14 @@ EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id, llama_batch_free(batch); seq->kv_position_count = total_end_position; + // First-token word-continuation constraint: when the caret is mid-word, mask new-word-start + // tokens for this seed only so the completion continues the current word instead of starting + // a new one. The flag clears after this single token. + if (seq->force_word_continuation) { + impl_->maskNewWordStarts(-1); + seq->force_word_continuation = false; + } + // Seed sample: take one token from the prompt's logits row right now, // before any other sequence's decode can overwrite the shared logits // buffer. The seed will be returned by the next sampleNext call as-is @@ -875,6 +976,54 @@ int CotabbyInferenceEngine::getKVPositionCount(int32_t sequence_id) const { return seq ? seq->kv_position_count : 0; } +void CotabbyInferenceEngine::setForceWordContinuation(int32_t sequence_id, bool enabled) { + if (!impl_) return; + SequenceState* seq = impl_->findSequence(sequence_id); + if (seq) { + seq->force_word_continuation = enabled; + } +} + +// --------------------------------------------------------------------------- +// KV state snapshot / restore +// +// Thin wrappers over llama's single-sequence state copy. They serialize with the decoder thread +// via decode_mutex so a snapshot or restore never races an in-flight llama_decode. Callers pair a +// snapshot with the KV position count they observed and pass it back on restore, so this engine's +// own position bookkeeping stays consistent with the restored KV cache. +// --------------------------------------------------------------------------- + +size_t CotabbyInferenceEngine::snapshotSize(int32_t sequence_id) const { + if (!impl_ || !impl_->shared_ctx) return 0; + const SequenceState* seq = impl_->findSequence(sequence_id); + if (!seq) return 0; + return llama_state_seq_get_size(impl_->shared_ctx, seq->seq_id); +} + +size_t CotabbyInferenceEngine::snapshotSequence(int32_t sequence_id, uint8_t* dst, size_t capacity) { + if (!impl_ || !impl_->shared_ctx || !dst) return 0; + SequenceState* seq = impl_->findSequence(sequence_id); + if (!seq) return 0; + std::lock_guard lock(impl_->decode_mutex); + return llama_state_seq_get_data(impl_->shared_ctx, dst, capacity, seq->seq_id); +} + +bool CotabbyInferenceEngine::restoreSequence(int32_t sequence_id, const uint8_t* src, + size_t size, int position_count) { + if (!impl_ || !impl_->shared_ctx || !src) return false; + SequenceState* seq = impl_->findSequence(sequence_id); + if (!seq) return false; + std::lock_guard lock(impl_->decode_mutex); + const size_t read = llama_state_seq_set_data(impl_->shared_ctx, src, size, seq->seq_id); + if (read == 0) return false; + seq->kv_position_count = position_count; + // The restored blob invalidates any seed/pending token captured before; force the caller to + // re-prime via decodePrompt before the next sampleNext. + seq->has_seed_token = false; + seq->has_pending_input = false; + return true; +} + // --------------------------------------------------------------------------- // Cancellation // --------------------------------------------------------------------------- diff --git a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h index 5fdab1c..74d37e3 100644 --- a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h +++ b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -11,6 +12,10 @@ struct SamplingConfig { float min_p; float repetition_penalty; uint32_t seed; + + // When true, tokens that introduce a line break are masked from sampling so single-line + // fields never receive a multi-line completion. Defaults to false to preserve prior behavior. + bool single_line = false; }; struct SWIFT_SELF_CONTAINED SampleResult { @@ -101,6 +106,20 @@ class CotabbyInferenceEngine { bool trimKV(int32_t sequence_id, int keep_positions); int getKVPositionCount(int32_t sequence_id) const; + // Constrains the FIRST token of the next generation on `sequence_id` to continue the current + // word: tokens whose decoded text begins with whitespace are masked for that one token, then + // the constraint clears. Set this before `decodePrompt`, which samples the first (seed) token. + void setForceWordContinuation(int32_t sequence_id, bool enabled); + + // Single-sequence KV state snapshot/restore. `snapshotSize` reports the buffer size needed, + // `snapshotSequence` copies the sequence's KV state into `dst` (returns bytes written, 0 on + // failure), and `restoreSequence` loads a previously captured blob back and sets the KV + // position to `position_count` (returns false on failure). Blobs are model- and context- + // specific; never restore one across a model reload. + size_t snapshotSize(int32_t sequence_id) const; + size_t snapshotSequence(int32_t sequence_id, uint8_t* dst, size_t capacity); + bool restoreSequence(int32_t sequence_id, const uint8_t* src, size_t size, int position_count); + // Cancellation (thread-safe, non-blocking) void cancelSequence(int32_t sequence_id); diff --git a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift index 096eb0f..0e90e04 100644 --- a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift +++ b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift @@ -31,7 +31,8 @@ final class LlamaMiddlewareTests: XCTestCase { top_p: 0.7, min_p: 0.08, repetition_penalty: 1.05, - seed: 0 + seed: 0, + single_line: false ) let seqId = engine.createSequence(config) XCTAssertEqual(seqId, -1) @@ -144,9 +145,10 @@ final class LlamaMiddlewareTests: XCTestCase { XCTAssertFalse(templated.isEmpty) } - // Detokenize first token + // Detokenize a content token (the prompt's last token). Index 0 can be BOS, a control + // token that renders to zero bytes with special=false, so we avoid it here. var buf = [CChar](repeating: 0, count: 64) - let written = engine.detokenize(tokens[0], &buf, Int32(buf.count)) + let written = engine.detokenize(tokens[tokens.count - 1], &buf, Int32(buf.count)) XCTAssertGreaterThan(written, 0) // Create autocomplete sequence @@ -157,7 +159,8 @@ final class LlamaMiddlewareTests: XCTestCase { top_p: 0.7, min_p: 0.08, repetition_penalty: 1.05, - seed: 42 + seed: 42, + single_line: false ) let seqA = engine.createSequence(autoConfig) XCTAssertGreaterThan(seqA, 0) @@ -202,7 +205,8 @@ final class LlamaMiddlewareTests: XCTestCase { top_p: 0.95, min_p: 0.05, repetition_penalty: 1.4, - seed: 0 + seed: 0, + single_line: false ) let seqB = engine.createSequence(summaryConfig) XCTAssertGreaterThan(seqB, 0) @@ -245,12 +249,14 @@ final class LlamaMiddlewareTests: XCTestCase { let configA = SamplingConfig( max_prediction_tokens: 16, temperature: 0, top_k: 0, top_p: 0, min_p: 0, - repetition_penalty: 0, seed: 1 + repetition_penalty: 0, seed: 1, + single_line: false ) let configB = SamplingConfig( max_prediction_tokens: 16, temperature: 0, top_k: 0, top_p: 0, min_p: 0, - repetition_penalty: 0, seed: 2 + repetition_penalty: 0, seed: 2, + single_line: false ) let seqA = engine.createSequence(configA) @@ -310,7 +316,8 @@ final class LlamaMiddlewareTests: XCTestCase { let config = SamplingConfig( max_prediction_tokens: 32, temperature: 0, top_k: 0, top_p: 0, min_p: 0, - repetition_penalty: 0, seed: 0 + repetition_penalty: 0, seed: 0, + single_line: false ) let seq = engine.createSequence(config) @@ -335,4 +342,102 @@ final class LlamaMiddlewareTests: XCTestCase { engine.destroySequence(seq) } + + func testSetForceWordContinuationWithoutModelDoesNotCrash() { + var engine = CotabbyInferenceEngine() + engine.setForceWordContinuation(999, true) + engine.setForceWordContinuation(-1, false) + XCTAssertFalse(engine.isModelLoaded()) + } + + func testSnapshotSizeWithoutModelIsZero() { + let engine = CotabbyInferenceEngine() + XCTAssertEqual(engine.snapshotSize(1), 0) + } + + // With the first-token word-continuation constraint set, the seed token must not start a new + // word, i.e. its decoded text must not begin with whitespace. + func testForceWordContinuationConstrainsFirstToken() throws { + guard let modelPath = ProcessInfo.processInfo.environment["COTABBY_TEST_MODEL_PATH"] else { + try XCTSkipIf(true, "Set COTABBY_TEST_MODEL_PATH to a .gguf file to run this test") + return + } + var engine = CotabbyInferenceEngine() + XCTAssertEqual(engine.loadModel(modelPath, -1, 1024, 256), EngineStatus.ok) + defer { engine.unloadModel() } + + let config = SamplingConfig( + max_prediction_tokens: 8, temperature: 0, + top_k: 0, top_p: 0, min_p: 0, + repetition_penalty: 0, seed: 0, + single_line: false + ) + let seq = engine.createSequence(config) + XCTAssertGreaterThan(seq, 0) + + // Prompt ends mid-word ("writ"); the forced continuation must finish the word. + let prompt = "I am writ" + var tokens = Array(engine.tokenize(prompt, Int32(prompt.utf8.count))) + XCTAssertGreaterThan(tokens.count, 0) + + engine.setForceWordContinuation(seq, true) + XCTAssertEqual(engine.decodePrompt(seq, &tokens, Int32(tokens.count), 0), EngineStatus.ok) + + let result = engine.sampleNext(seq) + if !result.is_eos, let piece = result.piece, result.piece_length > 0 { + let text = String( + bytes: UnsafeBufferPointer( + start: UnsafeRawPointer(piece).assumingMemoryBound(to: UInt8.self), + count: Int(result.piece_length) + ), + encoding: .utf8 + ) ?? "" + if let firstChar = text.first { + XCTAssertFalse( + firstChar.isWhitespace, + "Forced word continuation must not begin the first token with whitespace" + ) + } + } + engine.destroySequence(seq) + } + + // Snapshotting a sequence then restoring it must return the engine's KV position bookkeeping + // to the captured value. + func testSnapshotRestorePreservesPosition() throws { + guard let modelPath = ProcessInfo.processInfo.environment["COTABBY_TEST_MODEL_PATH"] else { + try XCTSkipIf(true, "Set COTABBY_TEST_MODEL_PATH to a .gguf file to run this test") + return + } + var engine = CotabbyInferenceEngine() + XCTAssertEqual(engine.loadModel(modelPath, -1, 1024, 256), EngineStatus.ok) + defer { engine.unloadModel() } + + let config = SamplingConfig( + max_prediction_tokens: 8, temperature: 0, + top_k: 0, top_p: 0, min_p: 0, + repetition_penalty: 0, seed: 0, + single_line: false + ) + let seq = engine.createSequence(config) + let prompt = "The quick brown fox" + var tokens = Array(engine.tokenize(prompt, Int32(prompt.utf8.count))) + XCTAssertEqual(engine.decodePrompt(seq, &tokens, Int32(tokens.count), 0), EngineStatus.ok) + + let position = engine.getKVPositionCount(seq) + XCTAssertGreaterThan(position, 0) + + let size = engine.snapshotSize(seq) + XCTAssertGreaterThan(size, 0) + var buffer = [UInt8](repeating: 0, count: Int(size)) + let written = engine.snapshotSequence(seq, &buffer, size) + XCTAssertGreaterThan(written, 0) + + // Advance past the snapshot point, then restore back to it. + _ = engine.sampleNext(seq) + XCTAssertTrue(engine.restoreSequence(seq, buffer, written, position)) + XCTAssertEqual(engine.getKVPositionCount(seq), position) + + engine.destroySequence(seq) + } } From be64365ebc4e9e8eb6cc0e7f366c56a8067e094d Mon Sep 17 00:00:00 2001 From: Jacob Fu <141651335+FuJacob@users.noreply.github.com> Date: Sun, 31 May 2026 12:16:32 -0700 Subject: [PATCH 2/2] Add per-token logprob to SampleResult for confidence scoring computeLogprob returns the chosen token's log-probability under the raw model distribution (<= 0). It is set on both the batched sample path and the seed token sampled at decodePrompt, and returned with the seed via SampleResult.logprob. This lets the app suppress low-confidence completions. Adds a model-gated test. --- .../CotabbyInferenceEngine.cpp | 31 ++++++++++++++++++- .../include/CotabbyInferenceEngine.h | 3 ++ .../LlamaMiddlewareTests.swift | 30 ++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp index abdcf1c..06930b9 100644 --- a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp +++ b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp @@ -55,6 +55,9 @@ struct SequenceState { // Set by setForceWordContinuation; consumed (and cleared) when the next seed token is sampled. bool force_word_continuation = false; + // Log-probability of the seed token, computed at decodePrompt and returned with the seed. + float seed_logprob = 0.0f; + ~SequenceState() { if (sampler) { llama_sampler_free(sampler); } } @@ -71,7 +74,8 @@ struct SequenceState { has_seed_token(o.has_seed_token), pending_input_token(o.pending_input_token), has_pending_input(o.has_pending_input), - force_word_continuation(o.force_word_continuation) { + force_word_continuation(o.force_word_continuation), + seed_logprob(o.seed_logprob) { o.sampler = nullptr; } SequenceState& operator=(SequenceState&&) = delete; @@ -300,6 +304,28 @@ struct CotabbyInferenceEngine::Impl { } } + // Log-probability of `token` under the raw model distribution at `logits_row`, used as a + // confidence signal. Two O(vocab) passes; only invoked on the autocomplete path. + float computeLogprob(int logits_row, llama_token token) const { + if (!shared_ctx || !vocab) return 0.0f; + const float* logits = llama_get_logits_ith(shared_ctx, logits_row); + if (!logits) return 0.0f; + const int32_t n = llama_vocab_n_tokens(vocab); + if (token < 0 || token >= n) return 0.0f; + float maxLogit = -INFINITY; + for (llama_token t = 0; t < n; ++t) { + if (logits[t] > maxLogit) { maxLogit = logits[t]; } + } + double sumExp = 0.0; + for (llama_token t = 0; t < n; ++t) { + sumExp += std::exp(static_cast(logits[t] - maxLogit)); + } + if (!(sumExp > 0.0)) return 0.0f; + return static_cast( + static_cast(logits[token] - maxLogit) - std::log(sumExp) + ); + } + void destroyAllSequences() { std::lock_guard lock(sequences_mutex); for (auto& [id, seq] : sequences) { @@ -428,6 +454,7 @@ struct CotabbyInferenceEngine::Impl { r.token = next; r.piece = piece.c_str(); r.piece_length = static_cast(piece.size()); + r.logprob = computeLogprob(i, next); } } @@ -823,6 +850,7 @@ EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id, llama_token seed = llama_sampler_sample(seq->sampler, impl_->shared_ctx, -1); llama_sampler_accept(seq->sampler, seed); seq->seed_token = seed; + seq->seed_logprob = impl_->computeLogprob(-1, seed); seq->has_seed_token = true; seq->has_pending_input = false; @@ -893,6 +921,7 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) { result.token = next; result.piece = seq->last_piece.c_str(); result.piece_length = static_cast(seq->last_piece.size()); + result.logprob = seq->seed_logprob; return result; } diff --git a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h index 74d37e3..151d689 100644 --- a/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h +++ b/Sources/CotabbyInferenceEngine/include/CotabbyInferenceEngine.h @@ -24,6 +24,9 @@ struct SWIFT_SELF_CONTAINED SampleResult { int piece_length; bool is_eos; bool was_cancelled; + // Log-probability of the chosen token under the raw model distribution (<= 0). Used as a + // confidence signal; 0 for the EOS/cancelled cases where it carries no meaning. + float logprob; }; enum class EngineStatus : int { diff --git a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift index 0e90e04..a451e2b 100644 --- a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift +++ b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift @@ -440,4 +440,34 @@ final class LlamaMiddlewareTests: XCTestCase { engine.destroySequence(seq) } + + // A sampled (non-EOS) token must carry a finite log-probability that is <= 0, so the app can use + // it as a confidence signal. + func testSampleNextReportsFiniteLogprob() throws { + guard let modelPath = ProcessInfo.processInfo.environment["COTABBY_TEST_MODEL_PATH"] else { + try XCTSkipIf(true, "Set COTABBY_TEST_MODEL_PATH to a .gguf file to run this test") + return + } + var engine = CotabbyInferenceEngine() + XCTAssertEqual(engine.loadModel(modelPath, -1, 1024, 256), EngineStatus.ok) + defer { engine.unloadModel() } + + let config = SamplingConfig( + max_prediction_tokens: 4, temperature: 0, + top_k: 0, top_p: 0, min_p: 0, + repetition_penalty: 0, seed: 0, + single_line: false + ) + let seq = engine.createSequence(config) + let prompt = "The quick brown fox" + var tokens = Array(engine.tokenize(prompt, Int32(prompt.utf8.count))) + XCTAssertEqual(engine.decodePrompt(seq, &tokens, Int32(tokens.count), 0), EngineStatus.ok) + + let result = engine.sampleNext(seq) + if !result.is_eos { + XCTAssertTrue(result.logprob.isFinite, "logprob must be finite") + XCTAssertLessThanOrEqual(result.logprob, 0.0001, "a log-probability must be <= 0") + } + engine.destroySequence(seq) + } }