
[None][fix] KVCacheManagerV2 bug fixes (V2 remains default OFF) #12306

Merged
yizhang-nv merged 4 commits into NVIDIA:main from yizhang-nv:kv-cache-v2-fixes
Apr 24, 2026
Merged


Conversation

@yizhang-nv
Member

@yizhang-nv yizhang-nv commented Mar 18, 2026

@coderabbitai summary

Description

Port comprehensive bug fixes for KVCacheManagerV2 while keeping V2 off by default (use_kv_cache_manager_v2=False). The fallback mechanism still downgrades to V1 when an incompatibility is detected, but now warns instead of switching silently.

Fixes

KV Cache Core

  • Partial block rebase corruption: Fixed a condition in _kv_cache.py where a full tree block could corrupt shared pages during a partial rebase (added an is_full guard).
  • get_batch_cache_indices default layer_id: Made layer_id optional (defaulting to pool 0) so callers need not pass it when the layer is irrelevant.
  • max_blocks_per_seq undercount: Now computed after max_seq_len clamping, and accounts for num_extra_kv_tokens + max_total_draft_tokens so the host page-index buffer is large enough (see the first sketch after this list).
  • Multimodal block reuse: Different images/videos sharing the same placeholder token ID were incorrectly matched in the radix tree. Added _augment_tokens_for_block_reuse, which uses a content digest (Blake3 hash) to produce distinct tree entries (see the second sketch after this list).
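
For the max_blocks_per_seq fix, a minimal sketch of the corrected sizing arithmetic; the function name and exact clamp are illustrative, with parameter names following the description above:

```python
def compute_max_blocks_per_seq(max_seq_len: int, model_max_seq_len: int,
                               tokens_per_block: int,
                               num_extra_kv_tokens: int,
                               max_total_draft_tokens: int) -> int:
    # Clamp first, then add extra/draft-token headroom, then round up to
    # whole blocks so the host page-index buffer is never undersized.
    seq_len = min(max_seq_len, model_max_seq_len)
    total_tokens = seq_len + num_extra_kv_tokens + max_total_draft_tokens
    return (total_tokens + tokens_per_block - 1) // tokens_per_block
```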

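And for the multimodal reuse fix, a rough sketch of the digest-based augmentation idea. The real helpers are _augment_tokens_for_block_reuse and gen_multi_modal_tokens; everything below (names, signature, the blake3 usage) is an illustrative reconstruction, not the repo's implementation:

```python
import blake3  # PyPI 'blake3' package; assumed available for this sketch


def augment_tokens_for_block_reuse(tokens, mm_digests, mm_positions,
                                   mm_lengths, vocab_size):
    """Sketch: overwrite multimodal placeholder tokens with digest-derived
    pseudo-tokens so the radix tree distinguishes different images/videos
    that share the same placeholder token ID."""
    tokens = list(tokens)
    for digest, start, length in zip(mm_digests, mm_positions, mm_lengths,
                                     strict=True):  # strict: catch mismatches
        # Expand the Blake3 content digest into one 32-bit word per token.
        stream = blake3.blake3(digest).digest(length=4 * length)
        for i in range(length):
            word = int.from_bytes(stream[4 * i:4 * i + 4], "little")
            # Map outside the real vocabulary: identical content still
            # matches in the tree, distinct content never does.
            tokens[start + i] = vocab_size + word
    return tokens
```
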
KV Cache Manager

  • Auto-provision host cache tier: The V2 MAX_UTILIZATION scheduler relies on suspend/resume; without a host tier, suspended pages have nowhere to go, causing deadlock. The manager now auto-provisions a host tier matching the GPU quota, capped at 50% of available host memory (sketched after this list).
  • extend_capacity_for_tokens: Added to both KVCacheManager and KVCacheManagerV2 to extend token capacity for CUDA graph padding.
  • Float type in token estimation: Wrapped the max_num_tokens_in_memory computation in int() to prevent float propagation.
  • Draft KV cache stream: Fixed resume() to use draft_kv_cache_manager._stream instead of torch.cuda.current_stream().
  • Draft resource release: Fixed release_resources to also free the draft KV cache on context scheduling failure.
  • Context update resize: Simplified the resize to use None capacity for the context phase, removing a fragile draft-OOM workaround.
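
A minimal sketch of the host-tier auto-provisioning rule described above, assuming psutil for host-memory introspection; the function name and signature are hypothetical:

```python
import psutil  # assumed available here for host-memory introspection


def host_tier_quota_bytes(gpu_quota_bytes: int) -> int:
    # Suspended pages need somewhere to go under MAX_UTILIZATION
    # scheduling, so match the GPU quota, but never claim more than
    # half of the host memory that is currently available.
    available = psutil.virtual_memory().available
    return min(gpu_quota_bytes, available // 2)
```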

Scheduler V2

  • Two-phase scheduling: Context requests are now deferred to phase 2 so generation requests are fully budgeted first, preventing PEFT adapter eviction failures (see the sketch after this list).
  • PEFT pre-claim for GENERATION_TO_COMPLETE: Pre-claims PEFT pages for requests whose adapters haven't been released yet (a timing artifact of the overlap executor).
  • Deadlock detection: Raises a RuntimeError when generation requests exist but none can be scheduled or evicted (KV cache exhausted with no host tier).
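
A simplified sketch of the two-phase policy with the deadlock check; `budget.try_claim` and the request attributes are hypothetical stand-ins for the scheduler's real types:

```python
def schedule_two_phase(requests, budget):
    """Sketch only: generation first, then context, then a deadlock check."""
    scheduled = []
    # Phase 1: generation requests get first claim on KV pages and PEFT
    # adapter slots, so newly admitted contexts can never evict adapters
    # that in-flight generations still depend on.
    for req in (r for r in requests if r.is_generation):
        if budget.try_claim(req):
            scheduled.append(req)
    # Phase 2: context (prefill) requests fill whatever budget remains.
    for req in (r for r in requests if not r.is_generation):
        if budget.try_claim(req):
            scheduled.append(req)
    if not scheduled and any(r.is_generation for r in requests):
        # Generation work exists but nothing can run or be evicted:
        # KV cache is exhausted and there is no host tier to suspend to.
        raise RuntimeError("KV cache exhausted: no generation request "
                           "can be scheduled or evicted")
    return scheduled
```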

CUDA Graph

  • Prevent on-the-fly capture: Added a _capture_allowed flag and an allow_capture() context manager so workspace tensor reallocation cannot invalidate existing graphs (sketched after this list).
  • Disable graph replay during general warmup: Avoids replaying graphs whose KV cache block offsets are stale from capture time.
  • Eager fallback for uncaptured keys: Falls back to eager execution when a key was not captured during warmup, instead of crashing.
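
A condensed sketch of the capture guard and eager fallback, using plain torch.cuda.CUDAGraph APIs; the class and method names are illustrative rather than the repo's actual CUDAGraphRunner interface:

```python
import contextlib

import torch


class GraphRunnerSketch:
    """Illustrative only; not the repo's CUDAGraphRunner interface."""

    def __init__(self):
        self._capture_allowed = False
        self._graphs = {}  # batch-shape key -> torch.cuda.CUDAGraph

    @contextlib.contextmanager
    def allow_capture(self):
        # Capture is legal only inside warmup; elsewhere, capturing would
        # reallocate workspace tensors and invalidate existing graphs.
        self._capture_allowed = True
        try:
            yield
        finally:
            self._capture_allowed = False

    def run(self, key, eager_fn, *args):
        graph = self._graphs.get(key)
        if graph is not None:
            graph.replay()  # inputs were written into captured buffers
        elif not self._capture_allowed:
            eager_fn(*args)  # uncaptured key: eager fallback, not a crash
        else:
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                eager_fn(*args)
            self._graphs[key] = g
```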

Speculative Decoding

  • CUDA graph padding KV cache rewind mismatch: pad_draft_tokens_for_cuda_graph now extends KV cache capacity to match the padded length, preventing rewind underflow (sketched below). Implemented in drafter.py, model_drafter.py (two-model), and ngram.py.
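
A sketch of how padding and capacity extension stay in lockstep; extend_capacity_for_tokens is the method added by this PR, but its signature and the surrounding code here are assumptions:

```python
def pad_draft_tokens_for_cuda_graph(requests, kv_cache_manager,
                                    max_total_draft_tokens, pad_token_id=0):
    # Pad every request's draft tokens up to the CUDA graph's fixed
    # length, and extend KV capacity by the same amount so the later
    # rewind of rejected drafts cannot underflow the allocated pages.
    for req in requests:
        num_pad = max_total_draft_tokens - len(req.draft_tokens)
        if num_pad > 0:
            req.draft_tokens.extend([pad_token_id] * num_pad)
            kv_cache_manager.extend_capacity_for_tokens(req, num_pad)
```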

Fallback

  • Warn-only fallback: Refactored the V2 incompatibility check (beam search, kv_connector, event buffer, cache transceiver) so the downgrade to V1 is logged with a warning instead of happening silently (sketched below).
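
A sketch of a warn-on-downgrade check of this shape; the config attribute names (max_beam_width, kv_connector_manager, event_buffer_max_size) are assumptions based on the conditions listed above, not the code's actual fields:

```python
def warn_on_v2_downgrade(config, logger):
    reasons = []
    if config.max_beam_width > 1:
        reasons.append("beam search")
    if config.kv_connector_manager is not None:
        reasons.append("kv_connector")
    if config.event_buffer_max_size > 0:
        reasons.append("event buffer")
    if (config.cache_transceiver_config is not None
            and config.cache_transceiver_config.backend is not None):
        reasons.append("cache transceiver")
    if reasons:
        # The downgrade to V1 still happens; it is just no longer silent.
        logger.warning(
            "Falling back from KVCacheManagerV2 to KVCacheManager due to: %s",
            ", ".join(reasons))
    return bool(reasons)
```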

Test Coverage

  • Existing unit tests in tests/unittest/
  • Integration tests for multimodal models
  • CI pre-merge pipeline

PR Checklist

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 18, 2026

📝 Walkthrough

Two files were modified: one adds multimodal token augmentation support for block reuse in KV cache management, and the other refines the fallback conditions that trigger a downgrade from KVCacheManagerV2 to KVCacheManager when a cache_transceiver_config with a non-None backend is present.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Fallback condition updates: `tensorrt_llm/_torch/pyexecutor/_util.py` | Updated `KvCacheCreator.__init__` and `_create_one_model_draft_kv_cache_manager` to treat a `cache_transceiver_config` with a non-None backend as an additional trigger for the KVCacheManagerV2 fallback, resulting in a downgrade to KVCacheManager with a warning. |
| Multimodal token augmentation and API refinements: `tensorrt_llm/_torch/pyexecutor/resource_manager.py` | Added an `_augment_tokens_for_block_reuse` method to both KVCacheManager and KVCacheManagerV2 to augment multimodal token regions for block reuse. Updated the `get_batch_cache_indices` signature to accept an optional `layer_id` parameter. Modified `prepare_resources` and `update_resources` to augment tokens when block reuse is enabled. Imported `gen_multi_modal_tokens` and adjusted the `BAD_PAGE_INDEX` fallback logic in `get_block_ids_per_seq`. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Request as LlmRequest
    participant Manager as KVCacheManager/<br/>KVCacheManagerV2
    participant Augment as _augment_tokens_<br/>for_block_reuse
    participant GenTokens as gen_multi_modal_<br/>tokens
    participant Cache as KV Cache/<br/>Commit

    Request->>Manager: prepare_resources()
    alt Block reuse enabled & context chunk is first
        Manager->>Augment: _augment_tokens_for_block_reuse(tokens)
        Augment->>GenTokens: gen_multi_modal_tokens(multimodal regions)
        GenTokens-->>Augment: embedding content digests
        Augment-->>Manager: augmented tokens
        Manager->>Cache: _create_kv_cache(augmented_tokens)
    else Block reuse disabled
        Manager->>Cache: _create_kv_cache(original_tokens)
    end

    Request->>Manager: update_resources()
    alt Block reuse enabled & uncommitted tokens exist
        Manager->>Augment: _augment_tokens_for_block_reuse(committed_range)
        Augment->>GenTokens: gen_multi_modal_tokens()
        GenTokens-->>Augment: embedding content digests
        Augment-->>Manager: augmented tokens
        Manager->>Cache: commit(augmented_tokens)
    else
        Manager->>Cache: commit(original_tokens)
    end
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 10.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly identifies the main change as KVCacheManagerV2 bug fixes, with the important clarification that V2 remains default OFF, directly matching the PR's primary objective. |
| Description check | ✅ Passed | The PR description is comprehensive, with clear Description and Test Coverage sections, though the PR Checklist is not fully marked and some template items (dependencies, CODEOWNERS, documentation, diagrams) show no evidence of completion. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

39-40: Use a module import here to keep namespace-qualified usage.

This new direct function import diverges from the repo’s Python import rule and makes internal API provenance less clear.

💡 Suggested diff
```diff
-from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import \
-    gen_multi_modal_tokens
+from tensorrt_llm.runtime.kv_cache_manager_v2 import \
+    _block_radix_tree
-            mm_tokens = gen_multi_modal_tokens(self.vocab_size, digest, length)
+            mm_tokens = _block_radix_tree.gen_multi_modal_tokens(
+                self.vocab_size, digest, length)
```

As per coding guidelines "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use from package.subpackage import foo then foo.SomeClass())."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/resource_manager.py` around lines 39 - 40,
Replace the direct function import of gen_multi_modal_tokens with a module
import so the callsite remains namespace-qualified: import the module
tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree (instead of "from ...
import gen_multi_modal_tokens") and update all uses of gen_multi_modal_tokens in
resource_manager.py to call _block_radix_tree.gen_multi_modal_tokens so the
internal API provenance is preserved.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/_util.py`:
- Around line 127-129: The fallback warning for KVCacheManagerV2 is missing the
new cache transceiver condition; update the logger.warning message (the call
that currently mentions kv_connector_manager, beam width, and event buffer max
size) to also reference the cache transceiver condition (e.g., "cache
transceiver enabled" or the actual flag name used in code such as
cache_transceiver or cache_transceiver_enabled) so the predicate matches the
real checks that trigger the fallback from KVCacheManagerV2 to KVCacheManager.

In `@tensorrt_llm/_torch/pyexecutor/resource_manager.py`:
- Around line 2452-2457: free_resources() currently calls kv_cache.commit(...)
with raw tokens when committing an uncommitted tail, which can mix non-canonical
radix entries; update free_resources() to mirror the other commit path by
passing tokens through _augment_tokens_for_block_reuse before calling
kv_cache.commit. Specifically, locate the commit call in free_resources() and
replace the raw tokens argument (from req.get_tokens(...) or similar) with a
call to self._augment_tokens_for_block_reuse(req.get_tokens(DEFAULT_BEAM_INDEX),
req, start=kv_cache.num_committed_tokens, end=req.context_current_position) so
the same canonicalization logic used at the Line 2452–2457 path is applied
consistently.
- Around line 1967-1969: The for-loop that iterates over req.multimodal_hashes,
req.multimodal_positions, and req.multimodal_lengths uses zip() which silently
truncates mismatched lists; change the zip call to zip(req.multimodal_hashes,
req.multimodal_positions, req.multimodal_lengths, strict=True) so Python raises
on length mismatches and prevents silent token-augmentation errors; update the
loop where those symbols (req.multimodal_hashes, req.multimodal_positions,
req.multimodal_lengths) are zipped to include strict=True and run tests to
confirm a ValueError is raised for mismatched lengths.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c87a564d-166c-41ce-b64b-6deba6c4f38b

📥 Commits

Reviewing files that changed from the base of the PR and between 07e2440 and ff9f1e3.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py

@nvpohanh
Collaborator

@lowsfer could you review this? thanks

@yizhang-nv yizhang-nv requested a review from a team as a code owner March 23, 2026 09:52
@yizhang-nv yizhang-nv requested a review from ziyixiong-nv March 23, 2026 09:52
@yizhang-nv yizhang-nv changed the title [None][fix] Fix KVCacheManagerV2 fallback, block index, and multimodal block reuse [None][fix] KVCacheManagerV2 bug fixes (V2 remains default OFF) Mar 23, 2026
@yizhang-nv yizhang-nv force-pushed the kv-cache-v2-fixes branch 2 times, most recently from 6547686 to 8c6cc5e on March 23, 2026 10:26
@yizhang-nv yizhang-nv requested a review from lancelly March 23, 2026 10:27
@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40046 [ run ] triggered by Bot. Commit: fdf7bb0 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40046 [ run ] completed with state SUCCESS. Commit: fdf7bb0
/LLM/main/L0_MergeRequest_PR pipeline #31200 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yizhang-nv yizhang-nv requested a review from mikeiovine March 30, 2026 07:00
@yizhang-nv
Member Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42321 [ run ] triggered by Bot. Commit: 0b503ab Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42321 [ run ] completed with state SUCCESS. Commit: 0b503ab
/LLM/main/L0_MergeRequest_PR pipeline #33111 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yizhang-nv
Member Author

/bot run

@yizhang-nv
Member Author

/bot kill

@tensorrt-cicd
Collaborator

PR_Github #44353 [ run ] completed with state FAILURE. Commit: 82238df
/LLM/main/L0_MergeRequest_PR pipeline #34769 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #44383 [ run ] triggered by Bot. Commit: ea80f05 Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #44459 [ run ] triggered by Bot. Commit: ea80f05 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44383 [ run ] completed with state SUCCESS. Commit: ea80f05
/LLM/main/L0_MergeRequest_PR pipeline #34797 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44459 [ run ] completed with state SUCCESS. Commit: ea80f05
/LLM/main/L0_MergeRequest_PR pipeline #34865 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #44658 [ run ] triggered by Bot. Commit: 9a12dcf Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #44701 [ run ] triggered by Bot. Commit: 11302e6 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44701 [ run ] completed with state SUCCESS. Commit: 11302e6
/LLM/main/L0_MergeRequest_PR pipeline #35065 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

Port comprehensive bug fixes for KVCacheManagerV2 while keeping V2
default OFF. The fallback mechanism now warns instead of switching.

Fixes: partial block rebase corruption, max_blocks_per_seq undercount,
multimodal block reuse, auto-provision host cache tier, CUDA graph
capture/replay safety, draft KV cache stream and resource release,
two-phase scheduler with PEFT pre-claim and deadlock detection,
speculative decoding padding KV cache rewind mismatch, float type
propagation in token estimation.

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Remove _extend_kv_cache_for_padding from ModelDrafter and NGramDrafter
(padding is now handled in the drafter base class), simplify scheduler_v2
and resource_manager interfaces, and remove leftover debug prints from
KV cache manager V2.

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
…d correctness fixes

- Consolidate CUDA graph skip logic into CUDAGraphRunner.maybe_get_cuda_graph
- Fix get_num_available_tokens over-subtraction of extra_tokens
- Unify get_batch_cache_indices parameter name (layer_id -> layer_idx)
- Add defensive assert for multimodal hash length
- Update warning message to include cache transceiver condition
- Restore upstream _is_prop_supported implementation
- Fix copyright year, TYPE_CHECKING import, return type annotation

Signed-off-by: Yi Zhang <yizhang@nvidia.com>
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Port test_encoder_counts_toward_batch fix from enable-v2-by-default-v2:
increase max_batch_size to 2 so gen request wins phase 1 and encoder
fills the remaining slot in phase 2, verifying encoder correctly
counts toward batch size.

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #44935 [ run ] triggered by Bot. Commit: 53ce2fd Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45081 [ run ] triggered by Bot. Commit: 53ce2fd Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45081 [ run ] completed with state SUCCESS. Commit: 53ce2fd
/LLM/main/L0_MergeRequest_PR pipeline #35379 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yizhang-nv
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45297 [ run ] triggered by Bot. Commit: 53ce2fd Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45297 [ run ] completed with state SUCCESS. Commit: 53ce2fd
/LLM/main/L0_MergeRequest_PR pipeline #35551 completed with status: 'SUCCESS'

CI Report

Link to invocation

@yizhang-nv yizhang-nv merged commit 371e38d into NVIDIA:main Apr 24, 2026
5 checks passed