Phase 1: refactor engine to one shared context with batched decode#5
Merged
Conversation
Replaces the per-sequence llama_context architecture with a single shared context (n_seq_max = MAX_SEQUENCES) and a dedicated decoder thread that coalesces sample-step requests from multiple sequences into one llama_decode call. Public C++ API (CotabbyInferenceEngine.h) is unchanged; Cotabby's Swift code does not need to be modified. Why --- Phase 0 spike (see PR #3) showed that on M-series Metal, batched decode delivers 1.43x aggregate throughput at N=2 and up to 2.35x at N=4 vs the current "separate llama_context per sequence" design. The win comes from fusing matmul weight reads across sequences in a single llama_decode call: per-token decode is memory-bound on Apple Silicon, so a single decode that serves two sequences reuses the same weight read. The "Metal command queue serializes everything" pessimism does not survive empirically. Design ------ - Impl owns one llama_context with n_ctx = configured_ctx * MAX_SEQUENCES and n_seq_max = MAX_SEQUENCES. Each SequenceState carries a llama_seq_id slot (0..MAX_SEQUENCES-1) used to tag tokens in the shared KV cache. - Decoder thread loop: wait for at least one pending request, wait an additional BATCH_WINDOW_MICROS (200 µs by default) for siblings to pile in, then build one llama_batch carrying all pending tokens with their respective seq_ids, llama_decode once, sample each sequence's next token using its own sampler chain at its assigned batch index, and resolve every request's promise. - sampleNext fast path: deliver the seed token sampled at decodePrompt time. This avoids the decoder round-trip for the very first sample after a prompt, where there is no input token to feedback-decode. - sampleNext steady-state path: queue a PendingRequest (input token = previously-sampled token, position = current KV count, sampler = this sequence's chain) and wait on a std::promise resolved by the decoder thread. - decodePrompt holds decode_mutex for the prompt's chunk decode and takes the seed sample inline while the prompt's logits are still resident in the shared context. - trimKV holds decode_mutex, calls llama_memory_seq_rm for this sequence's seq_id, and invalidates any pending seed/input so the caller has to re-prime via decodePrompt before the next sampleNext. The 200 µs window is the throughput knob. Multi-sequence workloads naturally fall into lockstep because each sequence resubmits as soon as its sample returns, so successive requests usually arrive within the window without any caller-side coordination. Single-sequence callers pay one window's worth of latency per token (~2% of a ~10 ms decode); acceptable. Tunable later via a setter if needed. Cancellation ------------ - Existing one-way atomic flag preserved. - Checked at sampleNext entry (returns immediately) and again in processBatch after llama_decode but before sampling (skips wasted sample work, returns was_cancelled=true). The decode slot for a cancelled token is still consumed, which is fine — the slot is cheap; the win is not running the sampler. Tests ----- Added two integration tests gated on COTABBY_TEST_MODEL_PATH: - testInterleavedMultiSequenceSampling: alternates sampleNext between two sequences with greedy sampling and identical prompts, asserts identical output (validates the seed-token / feedback-decode handoff and per-sequence sampler isolation in the shared context). - testCancellationStopsSamplingPromptly: verifies sampleNext after cancelSequence returns was_cancelled=true without model work. Existing testEndToEndWithModel passes unchanged. Follow-ups ---------- - Bench scenario c_engine_threaded that exercises the full engine via its public API from two threads, for end-to-end throughput validation (Phase 0 numbers above are at the raw llama.cpp level). - README update: the "no shared decode mutex, no contention" claim is no longer accurate. The new design has a single decode_mutex serializing access to one llama_context. The contention is productive — it enables batching — but the README should reflect the new model.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Follow-up to #3. The Phase 0 spike measured a 1.43x-2.35x throughput win for batched decode on M-series Metal; this PR makes the engine actually use it.
(Re-opening because #4 closed when its base branch
feat/batched-decode-benchwas deleted on merge of #3. Same content, rebased onto main.)Summary
Replaces the per-sequence
llama_contextarchitecture with a single shared context (n_seq_max = MAX_SEQUENCES) and a dedicated decoder thread that coalesces sample-step requests from multiple sequences into onellama_decodecall. Public C++ API (CotabbyInferenceEngine.h) is unchanged; Cotabby's Swift code does not need to be modified.Design
llama_context,n_ctx = configured * MAX_SEQUENCES. EachSequenceStateowns allama_seq_idslot (0..MAX_SEQUENCES-1).BATCH_WINDOW_MICROS(200 µs) for siblings, build onellama_batchcarrying all pending tokens with their seq_ids,llama_decodeonce, sample each sequence's next token using its own sampler chain at its assigned batch index, resolve every request'sstd::promise.sampleNextfast path: deliver the seed token sampled atdecodePrompttime. No decoder round-trip on the first sample after a prompt.sampleNextsteady-state path: queue aPendingRequestand wait on the promise.decodePromptholdsdecode_mutexfor its chunk decodes and takes the seed sample inline while the prompt's logits are still resident.trimKVholdsdecode_mutex, callsllama_memory_seq_rm, and invalidates pending seed/input so the caller re-primes viadecodePrompt.The 200 µs window is the throughput knob. Multi-sequence workloads naturally fall into lockstep because each sequence resubmits as soon as its sample returns. Single-sequence callers pay one window per token (~2% of a ~10 ms decode).
Validation
Two new tests gated on
COTABBY_TEST_MODEL_PATH:testInterleavedMultiSequenceSampling: alternatessampleNextbetween two sequences with greedy sampling and identical prompts; asserts identical token output. Validates the seed-token / feedback-decode handoff and per-sequence sampler isolation in the shared context.testCancellationStopsSamplingPromptly: verifiessampleNextaftercancelSequencereturnswas_cancelled=truewithout model work.Existing
testEndToEndWithModelpasses unchanged — verifies the API stayed source-compatible for the single-sequence flow Cotabby uses today.Follow-ups
c_engine_threadedthat exercises the engine via its public API from two threads, for end-to-end throughput validation.decode_mutexserializing the singlellama_context. The contention is productive (it enables batching), but the docs should reflect the new model.tabby-4): no code changes needed for correctness; SPM ref pinning + comment touch-ups in a small follow-up PR.Risk / rollout notes
n_ctxis multiplied byMAX_SEQUENCES = 4, so KV cache memory ~4x. For Gemma 3 1B with 2048 ctx, that's ~107 MB — tolerable on M-series. Worth checking on lower-spec hardware.decodePromptholdsdecode_mutexfor its entire body, which serializes against the decoder thread. Verified with the interleaved-multi-sequence test.