Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
552 changes: 276 additions & 276 deletions demos/LLaVA.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/integration/model_bridge/test_bridge_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def test_text_generation(gpt2_bridge):
assert len(output) > len(prompt), "Generated text should be longer than the prompt"


@pytest.mark.skip(reason="KV cache support for TransformerBridge is currently incomplete")
def test_generate_with_kv_cache():
"""Test that generate works with use_past_kv_cache parameter."""
model_name = "gpt2" # Use a smaller model for testing
Expand Down
64 changes: 59 additions & 5 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,13 @@ def forward(
)
else:
input_ids = input

# Detect inputs_embeds: if the tensor is floating point, it's pre-computed
# embeddings (e.g., from multimodal models) rather than token IDs.
_is_inputs_embeds = (
isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
)

if attention_mask is not None:
kwargs["attention_mask"] = attention_mask
if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
Expand Down Expand Up @@ -1295,9 +1302,10 @@ def forward(
"Audio models require tensor input (raw waveform). "
"Pass a torch.Tensor or use input_values parameter."
)
elif _is_inputs_embeds:
output = self.original_model(inputs_embeds=input_ids, **kwargs)
else:
output = self.original_model(input_ids, **kwargs)

# Stash only the cache object (not the full output) for generate().
if getattr(self, "_capture_hf_cache", False):
self._last_hf_cache = getattr(output, "past_key_values", None)
Expand All @@ -1315,7 +1323,13 @@ def forward(
"Audio models do not support return_type='loss'. "
"CTC loss requires aligned frame-level labels."
)
# Use self.loss_fn for bfloat16 consistency (vs HF's cross_entropy)
if _is_inputs_embeds:
raise ValueError(
"Cannot compute loss with inputs_embeds — token IDs required for labels."
)
# Always use self.loss_fn for consistency with HT's formula
# (log_softmax + gather). HF's output.loss uses F.cross_entropy
# which gives different results in bfloat16.
assert isinstance(
logits, torch.Tensor
), f"Expected logits tensor, got {type(logits)}"
Expand All @@ -1326,6 +1340,10 @@ def forward(
"Audio models do not support return_type='both'. "
"CTC loss requires aligned frame-level labels."
)
if _is_inputs_embeds:
raise ValueError(
"Cannot compute loss with inputs_embeds — token IDs required for labels."
)
assert isinstance(
logits, torch.Tensor
), f"Expected logits tensor, got {type(logits)}"
Expand Down Expand Up @@ -1796,13 +1814,19 @@ def generate(
Generated sequence as string, list of strings, or tensor depending on input type and return_type.
If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
"""
# Convert input to tokens
# Convert input to tokens using to_tokens() for consistent special token handling
_generate_from_embeds = False
if isinstance(input, str):
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_type = "str"
elif isinstance(input, list):
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_type = "list"
elif isinstance(input, torch.Tensor) and input.is_floating_point():
# inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
input_tokens = input.to(self.cfg.device)
input_type = "embeds"
_generate_from_embeds = True
else:
input_tokens = input.to(self.cfg.device)
input_type = "tokens"
Expand All @@ -1811,6 +1835,8 @@ def generate(
if return_type == "input":
if input_type in ["str", "list"]:
return_type = "str"
elif input_type == "embeds":
return_type = "tokens"
else:
return_type = "tokens"

Expand Down Expand Up @@ -1861,6 +1887,9 @@ def generate(

# Generate tokens
current_tokens = input_tokens.clone()
# For inputs_embeds generation, also track generated token IDs for decoding
if _generate_from_embeds:
generated_token_ids: list[torch.Tensor] = []
sampled_tokens_list = []

# For encoder-decoder models, keep encoder input fixed and grow decoder input
Expand Down Expand Up @@ -1888,6 +1917,10 @@ def generate(
)
else:
forward_kwargs: Dict[str, Any] = {}
# Pass multimodal inputs only on the first step — the vision
# encoder processes the image once, embedding it into the
# token sequence. This includes pixel_values plus any extra
# processor outputs (e.g. image_sizes for LlavaNext).
if gen_step_idx == 0:
if pixel_values is not None:
forward_kwargs["pixel_values"] = pixel_values
Expand Down Expand Up @@ -1924,6 +1957,13 @@ def generate(
logits_seq_list.append(final_logits.clone())

# Sample next token
# For inputs_embeds, we can't pass the embeddings to freq/rep penalty,
# so use the generated_token_ids for penalty tracking
penalty_tokens = (
torch.stack(generated_token_ids, dim=1)
if _generate_from_embeds and generated_token_ids
else None
)
if do_sample:
sampled_tokens = utils.sample_logits(
final_logits,
Expand All @@ -1932,14 +1972,18 @@ def generate(
temperature=temperature,
freq_penalty=freq_penalty,
repetition_penalty=repetition_penalty,
tokens=decoder_tokens if is_encoder_decoder else current_tokens,
tokens=penalty_tokens
if _generate_from_embeds
else (decoder_tokens if is_encoder_decoder else current_tokens),
).to(self.cfg.device)
else:
sampled_tokens = utils.sample_logits(
final_logits,
temperature=0.0,
repetition_penalty=repetition_penalty,
tokens=decoder_tokens if is_encoder_decoder else current_tokens,
tokens=penalty_tokens
if _generate_from_embeds
else (decoder_tokens if is_encoder_decoder else current_tokens),
).to(self.cfg.device)

sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
Expand All @@ -1959,6 +2003,13 @@ def generate(
decoder_tokens = torch.cat(
[decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1
)
elif _generate_from_embeds:
# For inputs_embeds: get the embedding of the new token and append
generated_token_ids.append(sampled_tokens)
embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator]
assert embed_fn is not None
new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
current_tokens = torch.cat([current_tokens, new_embed], dim=1)
else:
current_tokens = torch.cat(
[current_tokens, sampled_tokens.unsqueeze(1)], dim=1
Expand All @@ -1978,6 +2029,9 @@ def generate(
sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
if is_encoder_decoder:
output_tokens = decoder_tokens
elif _generate_from_embeds:
# For inputs_embeds, we only have the generated token IDs (no input token IDs)
output_tokens = sampled_tokens
else:
output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
**kwargs: Any,
) -> torch.Tensor:
"""Forward pass through the vision encoder layer.

Args:
hidden_states: Input hidden states from previous layer
attention_mask: Optional attention mask
causal_attention_mask: Optional causal attention mask (used by CLIP encoder)
**kwargs: Additional arguments

Returns:
Expand All @@ -74,7 +76,12 @@ def forward(
)

hidden_states = self.hook_in(hidden_states)
output = self.original_component(hidden_states, attention_mask=attention_mask, **kwargs)
output = self.original_component(
hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
**kwargs,
)

if isinstance(output, tuple):
output = (self.hook_out(output[0]),) + output[1:]
Expand Down
Loading