From 43ea120ade9de783a94b35c27147959dcd618648 Mon Sep 17 00:00:00 2001 From: Jacob Fu <141651335+FuJacob@users.noreply.github.com> Date: Thu, 28 May 2026 03:37:23 -0700 Subject: [PATCH] Refactor engine to one shared context with batched decode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the per-sequence llama_context architecture with a single shared context (n_seq_max = MAX_SEQUENCES) and a dedicated decoder thread that coalesces sample-step requests from multiple sequences into one llama_decode call. Public C++ API (CotabbyInferenceEngine.h) is unchanged; Cotabby's Swift code does not need to be modified. Why --- Phase 0 spike (see PR #3) showed that on M-series Metal, batched decode delivers 1.43x aggregate throughput at N=2 and up to 2.35x at N=4 vs the current "separate llama_context per sequence" design. The win comes from fusing matmul weight reads across sequences in a single llama_decode call: per-token decode is memory-bound on Apple Silicon, so a single decode that serves two sequences reuses the same weight read. The "Metal command queue serializes everything" pessimism does not survive empirically. Design ------ - Impl owns one llama_context with n_ctx = configured_ctx * MAX_SEQUENCES and n_seq_max = MAX_SEQUENCES. Each SequenceState carries a llama_seq_id slot (0..MAX_SEQUENCES-1) used to tag tokens in the shared KV cache. - Decoder thread loop: wait for at least one pending request, wait an additional BATCH_WINDOW_MICROS (200 µs by default) for siblings to pile in, then build one llama_batch carrying all pending tokens with their respective seq_ids, llama_decode once, sample each sequence's next token using its own sampler chain at its assigned batch index, and resolve every request's promise. - sampleNext fast path: deliver the seed token sampled at decodePrompt time. This avoids the decoder round-trip for the very first sample after a prompt, where there is no input token to feedback-decode. - sampleNext steady-state path: queue a PendingRequest (input token = previously-sampled token, position = current KV count, sampler = this sequence's chain) and wait on a std::promise resolved by the decoder thread. - decodePrompt holds decode_mutex for the prompt's chunk decode and takes the seed sample inline while the prompt's logits are still resident in the shared context. - trimKV holds decode_mutex, calls llama_memory_seq_rm for this sequence's seq_id, and invalidates any pending seed/input so the caller has to re-prime via decodePrompt before the next sampleNext. The 200 µs window is the throughput knob. Multi-sequence workloads naturally fall into lockstep because each sequence resubmits as soon as its sample returns, so successive requests usually arrive within the window without any caller-side coordination. Single-sequence callers pay one window's worth of latency per token (~2% of a ~10 ms decode); acceptable. Tunable later via a setter if needed. Cancellation ------------ - Existing one-way atomic flag preserved. - Checked at sampleNext entry (returns immediately) and again in processBatch after llama_decode but before sampling (skips wasted sample work, returns was_cancelled=true). The decode slot for a cancelled token is still consumed, which is fine — the slot is cheap; the win is not running the sampler. Tests ----- Added two integration tests gated on COTABBY_TEST_MODEL_PATH: - testInterleavedMultiSequenceSampling: alternates sampleNext between two sequences with greedy sampling and identical prompts, asserts identical output (validates the seed-token / feedback-decode handoff and per-sequence sampler isolation in the shared context). - testCancellationStopsSamplingPromptly: verifies sampleNext after cancelSequence returns was_cancelled=true without model work. Existing testEndToEndWithModel passes unchanged. Follow-ups ---------- - Bench scenario c_engine_threaded that exercises the full engine via its public API from two threads, for end-to-end throughput validation (Phase 0 numbers above are at the raw llama.cpp level). - README update: the "no shared decode mutex, no contention" claim is no longer accurate. The new design has a single decode_mutex serializing access to one llama_context. The contention is productive — it enables batching — but the README should reflect the new model. --- .../CotabbyInferenceEngine.cpp | 501 ++++++++++++++---- .../LlamaMiddlewareTests.swift | 112 ++++ 2 files changed, 513 insertions(+), 100 deletions(-) diff --git a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp index 251ebd1..d1d77c2 100644 --- a/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp +++ b/Sources/CotabbyInferenceEngine/CotabbyInferenceEngine.cpp @@ -1,12 +1,17 @@ #include "CotabbyInferenceEngine.h" +#include #include +#include +#include #include +#include #include #include #include #include #include +#include #include #include @@ -14,28 +19,54 @@ static void silenced_log_callback(ggml_log_level, const char*, void*) {} // --------------------------------------------------------------------------- -// Per-sequence state (one llama_context + sampler per sequence) +// Per-sequence state +// +// Phase 1 architecture: all sequences share a single `llama_context` allocated +// in `Impl`. Each sequence owns its own sampler chain, KV-cache position +// counter, cancellation flag, and detokenization buffer. The `seq_id` is the +// internal `llama_seq_id` slot (0..MAX_SEQUENCES-1) used to tag this +// sequence's tokens in the shared KV cache. +// +// `seed_token` / `has_seed_token` carries the first sample produced by +// `decodePrompt`. We sample it right after the prompt's final decode while +// that sequence's logits are still live in the shared context; the next +// `llama_decode` for a different sequence would overwrite them. +// +// `pending_input_token` / `has_pending_input` carries the token that +// `sampleNext` returned and must be feedback-decoded on the next call so the +// shared context produces fresh logits at the new position. // --------------------------------------------------------------------------- struct SequenceState { - llama_context* context = nullptr; + llama_seq_id seq_id = -1; llama_sampler* sampler = nullptr; SamplingConfig config{}; int kv_position_count = 0; std::atomic cancelled{false}; std::string last_piece; + llama_token seed_token = 0; + bool has_seed_token = false; + + llama_token pending_input_token = 0; + bool has_pending_input = false; + ~SequenceState() { if (sampler) { llama_sampler_free(sampler); } - if (context) { llama_free(context); } } SequenceState() = default; SequenceState(SequenceState&& o) noexcept - : context(o.context), sampler(o.sampler), config(o.config), + : seq_id(o.seq_id), + sampler(o.sampler), + config(o.config), kv_position_count(o.kv_position_count), - cancelled(o.cancelled.load()), last_piece(std::move(o.last_piece)) { - o.context = nullptr; + cancelled(o.cancelled.load()), + last_piece(std::move(o.last_piece)), + seed_token(o.seed_token), + has_seed_token(o.has_seed_token), + pending_input_token(o.pending_input_token), + has_pending_input(o.has_pending_input) { o.sampler = nullptr; } SequenceState& operator=(SequenceState&&) = delete; @@ -43,31 +74,84 @@ struct SequenceState { SequenceState& operator=(const SequenceState&) = delete; }; +// --------------------------------------------------------------------------- +// Pending decode + sample request handed to the decoder thread. +// +// Holds raw pointers into a SequenceState entry. SequenceState entries live in +// a node-based unordered_map (stable addresses across inserts), and the +// public contract forbids destroying a sequence with sampleNext in flight, so +// the pointers stay valid for the request's lifetime. +// --------------------------------------------------------------------------- + +struct PendingRequest { + llama_seq_id seq_id = -1; + llama_token token = 0; + int position = 0; + llama_sampler* sampler = nullptr; + std::atomic* cancelled_ptr = nullptr; + std::string* piece_buffer = nullptr; + SampleResult* result_out = nullptr; + std::promise done; +}; + // --------------------------------------------------------------------------- // PIMPL // --------------------------------------------------------------------------- struct CotabbyInferenceEngine::Impl { static constexpr int MAX_SEQUENCES = 4; - static constexpr llama_seq_id SEQ_ID = 0; + + // Microseconds the decoder thread waits after the first request arrives + // before flushing. This is the knob that lets multi-sequence callers pile + // up tokens for a batched `llama_decode`. Too short → no batching; too + // long → single-sequence callers feel extra latency per token. 200µs was + // chosen as a starting point and should be tuned via the bench. + static constexpr int BATCH_WINDOW_MICROS = 200; llama_model* model = nullptr; const llama_vocab* vocab = nullptr; bool backend_initialized = false; std::string model_path; + llama_context* shared_ctx = nullptr; int context_window_tokens = 0; int batch_size = 0; int thread_count = 0; int gpu_layer_count = 0; + // Public-facing sequence map (external int32_t IDs → state) and the + // internal `llama_seq_id` slot allocator. mutable std::mutex sequences_mutex; std::unordered_map sequences; - int32_t next_sequence_id = 1; + int32_t next_external_id = 1; + bool seq_slot_in_use[MAX_SEQUENCES] = {false}; + + // Decoder thread. Owns all `llama_decode` calls on `shared_ctx` after + // model load, including both sample-step batches (built from + // `sampleNext` requests) and prompt-decode chunks (forwarded from + // `decodePrompt` via the same queue path). + std::mutex decode_mutex; + std::condition_variable request_cv; + std::vector pending; + std::thread decoder_thread; + bool decoder_should_stop = false; + bool decoder_running = false; + + int allocateSeqSlot() { + for (int i = 0; i < MAX_SEQUENCES; ++i) { + if (!seq_slot_in_use[i]) { + seq_slot_in_use[i] = true; + return i; + } + } + return -1; + } - // ----------------------------------------------------------------------- - // Helpers - // ----------------------------------------------------------------------- + void releaseSeqSlot(int slot) { + if (slot >= 0 && slot < MAX_SEQUENCES) { + seq_slot_in_use[slot] = false; + } + } SequenceState* findSequence(int32_t id) { std::lock_guard lock(sequences_mutex); @@ -86,18 +170,13 @@ struct CotabbyInferenceEngine::Impl { llama_sampler* chain = llama_sampler_chain_init(params); if (!chain) return nullptr; - // 1. Repetition penalty if (cfg.repetition_penalty > 1.0f) { auto* pen = llama_sampler_init_penalties( - 64, - cfg.repetition_penalty, - 0.0f, - 0.0f + 64, cfg.repetition_penalty, 0.0f, 0.0f ); if (pen) llama_sampler_chain_add(chain, pen); } - // 2a. Stochastic path if (cfg.temperature > 0.0f) { auto* temp = llama_sampler_init_temp(cfg.temperature); if (temp) llama_sampler_chain_add(chain, temp); @@ -124,8 +203,6 @@ struct CotabbyInferenceEngine::Impl { } auto* dist = llama_sampler_init_dist(resolved_seed); if (dist) llama_sampler_chain_add(chain, dist); - - // 2b. Greedy path } else { auto* greedy = llama_sampler_init_greedy(); if (greedy) llama_sampler_chain_add(chain, greedy); @@ -136,8 +213,141 @@ struct CotabbyInferenceEngine::Impl { void destroyAllSequences() { std::lock_guard lock(sequences_mutex); + for (auto& [id, seq] : sequences) { + releaseSeqSlot(seq.seq_id); + } sequences.clear(); } + + void startDecoderThread() { + if (decoder_running) return; + decoder_should_stop = false; + decoder_running = true; + decoder_thread = std::thread([this]() { decoderRun(); }); + } + + void stopDecoderThread() { + if (!decoder_running) return; + { + std::lock_guard lock(decode_mutex); + decoder_should_stop = true; + } + request_cv.notify_all(); + if (decoder_thread.joinable()) { + decoder_thread.join(); + } + decoder_running = false; + } + + // Decoder loop: collect pending sample-step requests, batch them into one + // llama_decode, sample each sequence's next token using its own sampler, + // and resolve every request's promise so the caller threads can return. + void decoderRun() { + while (true) { + std::vector batch; + { + std::unique_lock lock(decode_mutex); + request_cv.wait(lock, [&] { + return decoder_should_stop || !pending.empty(); + }); + if (decoder_should_stop && pending.empty()) return; + + // Brief flush window: when only one request has arrived, + // wait a short period for siblings to pile in so the next + // llama_decode can batch them together. Multi-sequence + // workloads naturally fall into lockstep here because each + // sequence resubmits as soon as its previous sample returns. + request_cv.wait_for( + lock, + std::chrono::microseconds(BATCH_WINDOW_MICROS), + [&] { return decoder_should_stop; } + ); + + if (pending.empty()) { + if (decoder_should_stop) return; + continue; + } + + batch = std::move(pending); + pending.clear(); + + // Process while still holding decode_mutex so prompt-decode + // calls in `decodePrompt` cannot race with the sample-step + // llama_decode below. llama_decode + sampling for a small + // batch is ~10ms; callers blocked on staging will simply + // queue for the next cycle. + processBatch(batch); + } + } + } + + void processBatch(std::vector& reqs) { + if (reqs.empty() || !shared_ctx) return; + + llama_batch batch = llama_batch_init( + static_cast(reqs.size() + 4), 0, 1 + ); + batch.n_tokens = static_cast(reqs.size()); + for (int i = 0; i < static_cast(reqs.size()); ++i) { + batch.token[i] = reqs[i].token; + batch.pos[i] = static_cast(reqs[i].position); + batch.n_seq_id[i] = 1; + if (batch.seq_id && batch.seq_id[i]) { + batch.seq_id[i][0] = reqs[i].seq_id; + } + batch.logits[i] = 1; + } + + int status = llama_decode(shared_ctx, batch); + + for (int i = 0; i < static_cast(reqs.size()); ++i) { + PendingRequest& req = reqs[i]; + SampleResult r{}; + r.token = 0; + r.piece = nullptr; + r.piece_length = 0; + r.is_eos = false; + r.was_cancelled = false; + + if (status != 0) { + r.is_eos = true; + } else if (req.cancelled_ptr && + req.cancelled_ptr->load(std::memory_order_acquire)) { + r.was_cancelled = true; + } else { + llama_token next = llama_sampler_sample( + req.sampler, shared_ctx, i + ); + if (next == llama_vocab_eos(vocab) || + llama_vocab_is_eog(vocab, next)) { + r.token = next; + r.is_eos = true; + } else { + llama_sampler_accept(req.sampler, next); + + std::string& piece = *req.piece_buffer; + piece.resize(64); + while (true) { + int written = llama_token_to_piece( + vocab, next, piece.data(), + static_cast(piece.size()), 0, false + ); + if (written >= 0) { piece.resize(written); break; } + piece.resize(static_cast(-written) + 1); + } + + r.token = next; + r.piece = piece.c_str(); + r.piece_length = static_cast(piece.size()); + } + } + + *req.result_out = r; + req.done.set_value(); + } + + llama_batch_free(batch); + } }; // --------------------------------------------------------------------------- @@ -167,12 +377,10 @@ EngineStatus CotabbyInferenceEngine::loadModel(const char* path, int gpu_layers, int batch_size) { if (!impl_ || !path) return EngineStatus::error; - // Idempotent for same path if (impl_->model && impl_->model_path == path) { return EngineStatus::ok; } - // Different model already loaded — tear down first if (impl_->model) { unloadModel(); } @@ -208,13 +416,45 @@ EngineStatus CotabbyInferenceEngine::loadModel(const char* path, int gpu_layers, std::max(1u, std::thread::hardware_concurrency()) ); + // Shared context sized to hold MAX_SEQUENCES sequences each with up to + // `context_window_tokens` KV slots. llama.cpp's `n_ctx` is the total slot + // budget across all sequences in a context, so we multiply to give each + // sequence the configured window without sequences stealing slots from + // each other. + auto ctx_params = llama_context_default_params(); + ctx_params.n_ctx = static_cast( + context_window_tokens * Impl::MAX_SEQUENCES + ); + ctx_params.n_batch = static_cast(batch_size); + ctx_params.n_ubatch = static_cast(batch_size); + ctx_params.n_seq_max = static_cast(Impl::MAX_SEQUENCES); + ctx_params.n_threads = static_cast(impl_->thread_count); + ctx_params.n_threads_batch = static_cast(impl_->thread_count); + ctx_params.offload_kqv = true; + + impl_->shared_ctx = llama_init_from_model(impl_->model, ctx_params); + if (!impl_->shared_ctx) { + llama_model_free(impl_->model); + impl_->model = nullptr; + impl_->vocab = nullptr; + return EngineStatus::error; + } + + impl_->startDecoderThread(); return EngineStatus::ok; } void CotabbyInferenceEngine::unloadModel() { if (!impl_) return; + + impl_->stopDecoderThread(); impl_->destroyAllSequences(); + if (impl_->shared_ctx) { + llama_free(impl_->shared_ctx); + impl_->shared_ctx = nullptr; + } + if (impl_->model) { llama_model_free(impl_->model); impl_->model = nullptr; @@ -229,7 +469,7 @@ void CotabbyInferenceEngine::unloadModel() { } bool CotabbyInferenceEngine::isModelLoaded() const { - return impl_ && impl_->model != nullptr; + return impl_ && impl_->model != nullptr && impl_->shared_ctx != nullptr; } // --------------------------------------------------------------------------- @@ -237,49 +477,60 @@ bool CotabbyInferenceEngine::isModelLoaded() const { // --------------------------------------------------------------------------- int32_t CotabbyInferenceEngine::createSequence(SamplingConfig config) { - if (!impl_->model) return -1; + if (!impl_->model || !impl_->shared_ctx) return -1; - { - std::lock_guard lock(impl_->sequences_mutex); - if (static_cast(impl_->sequences.size()) >= Impl::MAX_SEQUENCES) { - return -1; - } - } - - // Build context - auto ctx_params = llama_context_default_params(); - ctx_params.n_ctx = static_cast(impl_->context_window_tokens); - ctx_params.n_batch = static_cast(impl_->batch_size); - ctx_params.n_ubatch = static_cast(impl_->batch_size); - ctx_params.n_seq_max = 1; - ctx_params.n_threads = static_cast(impl_->thread_count); - ctx_params.n_threads_batch = static_cast(impl_->thread_count); - ctx_params.offload_kqv = true; - - llama_context* ctx = llama_init_from_model(impl_->model, ctx_params); - if (!ctx) return -1; + std::lock_guard lock(impl_->sequences_mutex); + int slot = impl_->allocateSeqSlot(); + if (slot < 0) return -1; - // Build sampler llama_sampler* sampler = impl_->buildSampler(config); if (!sampler) { - llama_free(ctx); + impl_->releaseSeqSlot(slot); return -1; } SequenceState state; - state.context = ctx; + state.seq_id = static_cast(slot); state.sampler = sampler; state.config = config; - std::lock_guard lock(impl_->sequences_mutex); - int32_t id = impl_->next_sequence_id++; + int32_t id = impl_->next_external_id++; impl_->sequences.emplace(id, std::move(state)); return id; } void CotabbyInferenceEngine::destroySequence(int32_t sequence_id) { - std::lock_guard lock(impl_->sequences_mutex); - impl_->sequences.erase(sequence_id); + if (!impl_) return; + + // Look up the internal slot once. Caller's contract is to not destroy a + // sequence with sampleNext in flight, so the entry is stable for the + // duration of this call. + llama_seq_id slot_to_wipe = -1; + { + std::lock_guard seq_lock(impl_->sequences_mutex); + auto it = impl_->sequences.find(sequence_id); + if (it == impl_->sequences.end()) return; + slot_to_wipe = it->second.seq_id; + } + + // Wipe this sequence's KV slots in the shared context before releasing + // the slot, otherwise stale positions linger and reusing the slot later + // would mix old tokens with new ones. Hold decode_mutex so the wipe does + // not race with the decoder thread's llama_decode call. + if (impl_->shared_ctx && slot_to_wipe >= 0) { + std::lock_guard decode_lock(impl_->decode_mutex); + llama_memory_t memory = llama_get_memory(impl_->shared_ctx); + if (memory) { + llama_memory_seq_rm(memory, slot_to_wipe, 0, -1); + } + } + + std::lock_guard seq_lock(impl_->sequences_mutex); + auto it = impl_->sequences.find(sequence_id); + if (it != impl_->sequences.end()) { + impl_->releaseSeqSlot(it->second.seq_id); + impl_->sequences.erase(it); + } } // --------------------------------------------------------------------------- @@ -314,7 +565,6 @@ std::vector CotabbyInferenceEngine::tokenize(const char* text, if (n == 0) { return {}; } - // n < 0 means buffer too small, -n is the required capacity capacity = std::max(capacity * 2, -n); } } @@ -337,13 +587,22 @@ int CotabbyInferenceEngine::detokenize(int32_t token, char* buffer, // --------------------------------------------------------------------------- // Prompt decoding +// +// Prompt decode runs synchronously on the calling thread, but takes +// `decode_mutex` so it serializes with the decoder thread's sample-step +// llama_decode calls. After the prompt's final decode succeeds, we +// immediately sample one "seed" token using this sequence's sampler while +// the prompt's logits are still live in the shared context. That seed is +// handed back via the very next `sampleNext` call without any further +// decode work — subsequent calls feedback-decode this seed (then each +// previous sample) to produce fresh logits for the next sample. // --------------------------------------------------------------------------- EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id, const int32_t* tokens, int token_count, int start_position) { - if (!impl_->model) return EngineStatus::not_loaded; + if (!impl_->model || !impl_->shared_ctx) return EngineStatus::not_loaded; if (!tokens || token_count <= 0) return EngineStatus::ok; SequenceState* seq = impl_->findSequence(sequence_id); @@ -353,6 +612,8 @@ EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id, return EngineStatus::cancelled; } + std::unique_lock lock(impl_->decode_mutex); + int batch_cap = impl_->batch_size; llama_batch batch = llama_batch_init(static_cast(batch_cap), 0, 1); @@ -377,14 +638,13 @@ EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id, batch.pos[i] = static_cast(start_position + token_index); batch.n_seq_id[i] = 1; if (batch.seq_id && batch.seq_id[i]) { - batch.seq_id[i][0] = Impl::SEQ_ID; + batch.seq_id[i][0] = seq->seq_id; } - // Logits only for the very last token of the entire prompt bool is_last = (chunk_end == end && i == chunk_size - 1); batch.logits[i] = is_last ? 1 : 0; } - if (llama_decode(seq->context, batch) != 0) { + if (llama_decode(impl_->shared_ctx, batch) != 0) { llama_batch_free(batch); return EngineStatus::error; } @@ -394,11 +654,26 @@ EngineStatus CotabbyInferenceEngine::decodePrompt(int32_t sequence_id, llama_batch_free(batch); seq->kv_position_count = total_end_position; + + // Seed sample: take one token from the prompt's logits row right now, + // before any other sequence's decode can overwrite the shared logits + // buffer. The seed will be returned by the next sampleNext call as-is + // and feedback-decoded by the call after that. + llama_token seed = llama_sampler_sample(seq->sampler, impl_->shared_ctx, -1); + llama_sampler_accept(seq->sampler, seed); + seq->seed_token = seed; + seq->has_seed_token = true; + seq->has_pending_input = false; + return EngineStatus::ok; } // --------------------------------------------------------------------------- // Sampling +// +// First call after decodePrompt: return the seed token directly (no decode +// queued). Subsequent calls: queue the previously sampled token for feedback +// decode via the decoder thread, then return its sampled result. // --------------------------------------------------------------------------- SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) { @@ -409,7 +684,7 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) { result.is_eos = false; result.was_cancelled = false; - if (!impl_->model || !impl_->vocab) { + if (!impl_->model || !impl_->vocab || !impl_->shared_ctx) { result.is_eos = true; return result; } @@ -425,63 +700,79 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) { return result; } - // Sample - llama_token next_token = llama_sampler_sample(seq->sampler, seq->context, -1); + // Fast path: deliver the seed token sampled at decodePrompt time. No + // shared-context work needed because the seed was already computed under + // decode_mutex while the prompt's logits were resident. + if (seq->has_seed_token) { + llama_token next = seq->seed_token; + seq->has_seed_token = false; + + if (next == llama_vocab_eos(impl_->vocab) || + llama_vocab_is_eog(impl_->vocab, next)) { + result.token = next; + result.is_eos = true; + return result; + } - // Check EOS / EOG - if (next_token == llama_vocab_eos(impl_->vocab) || - llama_vocab_is_eog(impl_->vocab, next_token)) { - result.token = next_token; - result.is_eos = true; + seq->last_piece.resize(64); + while (true) { + int written = llama_token_to_piece( + impl_->vocab, next, seq->last_piece.data(), + static_cast(seq->last_piece.size()), 0, false + ); + if (written >= 0) { seq->last_piece.resize(written); break; } + seq->last_piece.resize(static_cast(-written) + 1); + } + + // The seed has not yet been added to KV. Queue it as the next + // feedback-decode input so the call after this one has fresh logits. + seq->pending_input_token = next; + seq->has_pending_input = true; + + result.token = next; + result.piece = seq->last_piece.c_str(); + result.piece_length = static_cast(seq->last_piece.size()); return result; } - // Detokenize into the sequence's persistent buffer - seq->last_piece.resize(64); - while (true) { - int written = llama_token_to_piece( - impl_->vocab, - next_token, - seq->last_piece.data(), - static_cast(seq->last_piece.size()), - 0, - false - ); - if (written >= 0) { - seq->last_piece.resize(written); - break; - } - seq->last_piece.resize(static_cast(-written) + 1); + if (!seq->has_pending_input) { + // Nothing to decode and no seed — caller forgot to decodePrompt or + // trimmed the KV down to nothing without re-priming. + result.is_eos = true; + return result; } - llama_sampler_accept(seq->sampler, next_token); + // Steady-state path: queue the previously sampled token (or the seed, + // if this is the call right after the seed was delivered) for feedback + // decode via the decoder thread, which batches it with any other + // sequences that happen to be sampling at the same time. + PendingRequest req; + req.seq_id = seq->seq_id; + req.token = seq->pending_input_token; + req.position = seq->kv_position_count; + req.sampler = seq->sampler; + req.cancelled_ptr = &seq->cancelled; + req.piece_buffer = &seq->last_piece; + req.result_out = &result; + auto done_future = req.done.get_future(); - // Decode the sampled token to advance KV cache - llama_batch batch = llama_batch_init(1, 0, 1); - batch.n_tokens = 1; - batch.token[0] = next_token; - batch.pos[0] = static_cast(seq->kv_position_count); - batch.n_seq_id[0] = 1; - if (batch.seq_id && batch.seq_id[0]) { - batch.seq_id[0][0] = Impl::SEQ_ID; + { + std::lock_guard lock(impl_->decode_mutex); + impl_->pending.push_back(std::move(req)); + impl_->request_cv.notify_one(); } - batch.logits[0] = 1; - int decode_status = llama_decode(seq->context, batch); - llama_batch_free(batch); + done_future.wait(); - if (decode_status != 0) { - result.is_eos = true; + if (result.is_eos || result.was_cancelled) { return result; } + // Feedback decode advanced KV by one position; record the just-sampled + // token as input for the next call. seq->kv_position_count++; - - result.token = next_token; - result.piece = seq->last_piece.c_str(); - result.piece_length = static_cast(seq->last_piece.size()); - result.is_eos = false; - result.was_cancelled = false; + seq->pending_input_token = result.token; + seq->has_pending_input = true; return result; } @@ -490,21 +781,31 @@ SampleResult CotabbyInferenceEngine::sampleNext(int32_t sequence_id) { // --------------------------------------------------------------------------- bool CotabbyInferenceEngine::trimKV(int32_t sequence_id, int keep_positions) { + if (!impl_->shared_ctx) return false; SequenceState* seq = impl_->findSequence(sequence_id); if (!seq) return false; - llama_memory_t memory = llama_get_memory(seq->context); + llama_memory_t memory = llama_get_memory(impl_->shared_ctx); if (!memory) return false; + // Serialize with the decoder thread; we don't want to remove KV slots + // mid-batch. + std::lock_guard lock(impl_->decode_mutex); + bool ok = llama_memory_seq_rm( memory, - Impl::SEQ_ID, + seq->seq_id, static_cast(keep_positions), -1 ); if (ok) { seq->kv_position_count = keep_positions; + // Any seed/pending input is now stale (it would feedback-decode into + // a trimmed-away position). Caller must call decodePrompt to re-seed + // before the next sampleNext. + seq->has_seed_token = false; + seq->has_pending_input = false; } return ok; } diff --git a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift index 064c1d2..0826289 100644 --- a/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift +++ b/Tests/CotabbyInferenceTests/LlamaMiddlewareTests.swift @@ -173,4 +173,116 @@ final class LlamaMiddlewareTests: XCTestCase { engine.unloadModel() XCTAssertFalse(engine.isModelLoaded()) } + + // Multi-sequence sequential test. Drives the new shared-context decoder + // thread through two interleaved sampling loops to verify the seed-token + // / feedback-decode handoff produces valid tokens for both sequences + // when their sampleNext calls alternate. + func testInterleavedMultiSequenceSampling() throws { + guard let modelPath = ProcessInfo.processInfo.environment["COTABBY_TEST_MODEL_PATH"] else { + try XCTSkipIf(true, "Set COTABBY_TEST_MODEL_PATH to a .gguf file to run this test") + return + } + + var engine = CotabbyInferenceEngine() + XCTAssertEqual(engine.loadModel(modelPath, -1, 1024, 256), EngineStatus.ok) + defer { engine.unloadModel() } + + let prompt = "The quick brown fox jumps over the lazy dog." + let tokens = engine.tokenize(prompt, Int32(prompt.utf8.count)) + XCTAssertGreaterThan(tokens.size(), 0) + + let configA = SamplingConfig( + max_prediction_tokens: 16, temperature: 0, + top_k: 0, top_p: 0, min_p: 0, + repetition_penalty: 0, seed: 1 + ) + let configB = SamplingConfig( + max_prediction_tokens: 16, temperature: 0, + top_k: 0, top_p: 0, min_p: 0, + repetition_penalty: 0, seed: 2 + ) + + let seqA = engine.createSequence(configA) + let seqB = engine.createSequence(configB) + XCTAssertGreaterThan(seqA, 0) + XCTAssertGreaterThan(seqB, 0) + + var tokenArr = Array(tokens) + XCTAssertEqual( + engine.decodePrompt(seqA, &tokenArr, Int32(tokenArr.count), 0), + EngineStatus.ok + ) + XCTAssertEqual( + engine.decodePrompt(seqB, &tokenArr, Int32(tokenArr.count), 0), + EngineStatus.ok + ) + + // Alternate sampleNext between the two sequences. With greedy + // sampling and identical prompts, the first sampled tokens for both + // sequences should be identical (different samplers reading the + // same logits row at separate decodePrompt times). + var sampledA: [Int32] = [] + var sampledB: [Int32] = [] + for _ in 0..<8 { + let rA = engine.sampleNext(seqA) + let rB = engine.sampleNext(seqB) + XCTAssertFalse(rA.was_cancelled) + XCTAssertFalse(rB.was_cancelled) + if rA.is_eos || rB.is_eos { break } + sampledA.append(rA.token) + sampledB.append(rB.token) + } + XCTAssertEqual(sampledA.count, sampledB.count) + XCTAssertGreaterThan(sampledA.count, 0) + XCTAssertEqual(sampledA, sampledB, + "Greedy sampling with identical prompts should match across sequences") + + engine.destroySequence(seqA) + engine.destroySequence(seqB) + } + + // Cancellation regression: setting cancelled on a sequence mid-loop must + // cause subsequent sampleNext calls to return was_cancelled=true so + // callers can break out without waiting for the full prediction budget. + func testCancellationStopsSamplingPromptly() throws { + guard let modelPath = ProcessInfo.processInfo.environment["COTABBY_TEST_MODEL_PATH"] else { + try XCTSkipIf(true, "Set COTABBY_TEST_MODEL_PATH to a .gguf file to run this test") + return + } + + var engine = CotabbyInferenceEngine() + XCTAssertEqual(engine.loadModel(modelPath, -1, 1024, 256), EngineStatus.ok) + defer { engine.unloadModel() } + + let prompt = "Hello" + let tokens = engine.tokenize(prompt, Int32(prompt.utf8.count)) + let config = SamplingConfig( + max_prediction_tokens: 32, temperature: 0, + top_k: 0, top_p: 0, min_p: 0, + repetition_penalty: 0, seed: 0 + ) + + let seq = engine.createSequence(config) + var tokenArr = Array(tokens) + XCTAssertEqual( + engine.decodePrompt(seq, &tokenArr, Int32(tokenArr.count), 0), + EngineStatus.ok + ) + + // Sample a couple of tokens first. + for _ in 0..<2 { + let r = engine.sampleNext(seq) + XCTAssertFalse(r.was_cancelled) + } + + // Cancel. The next sampleNext should return was_cancelled=true + // without doing any further model work. + engine.cancelSequence(seq) + let cancelled = engine.sampleNext(seq) + XCTAssertTrue(cancelled.was_cancelled, + "sampleNext after cancelSequence must return was_cancelled=true") + + engine.destroySequence(seq) + } }