olive/common/hf/wrapper.py (25 changes: 20 additions & 5 deletions)
@@ -89,11 +89,18 @@ def create_packed(self) -> nn.Linear:
class LayerWrapper:
"""Wrapper for transformer layer block."""

FIRST_LAYER_NORM = {"default": "input_layernorm", "gpt2": "ln_1", "opt": "self_attn_layer_norm", "qwen": "ln_1"}
FIRST_LAYER_NORM = {
"default": "input_layernorm",
"gpt2": "ln_1",
"lfm2": "operator_norm",
"opt": "self_attn_layer_norm",
"qwen": "ln_1",
}
SECOND_LAYER_NORM = {
"default": "post_attention_layernorm",
"gemma2": "pre_feedforward_layernorm",
"gpt2": "ln_2",
"lfm2": "ffn_norm",
"opt": "final_layer_norm",
"qwen": "ln_2",
}
@@ -109,14 +116,16 @@ class LayerWrapper:
"default": ["o_proj"],
"bloom": ["dense"],
"gpt2": ["c_proj"],
"lfm2": ["out_proj"],
"opt": ["out_proj"],
"qwen": ["c_proj"],
}
MLP = {"default": "mlp", "opt": ""}
MLP = {"default": "mlp", "lfm2": "feed_forward", "opt": ""}
MLP_INPUTS = {
"default": ["gate_proj", "up_proj"],
"bloom": ["dense_h_to_4h"],
"gpt2": ["c_fc"],
"lfm2": ["w1", "w3"],
"opt": ["fc1"],
"phi3": ["gate_up_proj"],
"qwen": ["w1", "w2"],
@@ -125,6 +134,7 @@ class LayerWrapper:
"default": ["down_proj"],
"bloom": ["dense_4h_to_h"],
"gpt2": ["c_proj"],
"lfm2": ["w2"],
"opt": ["fc2"],
"qwen": ["c_proj"],
}
@@ -134,8 +144,8 @@ def __init__(self, layer: nn.Module, model_type: str):
self.layer = layer
self.model_type = model_type

# Use fail_on_not_found=False to support hybrid architectures (e.g., Qwen3.5)
# where some layers use linear attention instead of standard self-attention
# Use fail_on_not_found=False to support hybrid architectures (e.g., Qwen3.5, LFM2)
# where some layers lack standard self-attention (linear attention or conv layers)
self.attn, self.attn_name = get_submodules(
layer, self.ATTENTION, self.model_type, return_name=True, fail_on_not_found=False
)
@@ -208,7 +218,12 @@ class ModelWrapper:
"qwen": "transformer.rotary_emb",
}
LM_HEAD = {"default": "lm_head"}
PRE_HEAD_LAYERNORM = {"default": "model.norm", "gpt2": "transformer.ln_f", "qwen": "transformer.ln_f"}
PRE_HEAD_LAYERNORM = {
"default": "model.norm",
"gpt2": "transformer.ln_f",
"lfm2": "model.embedding_norm",
"qwen": "transformer.ln_f",
}
LAYERS = {
"default": "model.layers",
"bloom": "transformer.h",
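The tables above map each supported model type to its submodule names, with `default` as the fallback entry. A minimal sketch of how such a table might be resolved against a loaded layer; the helper name `resolve_submodule` and its signature are hypothetical stand-ins for Olive's `get_submodules`, which additionally supports `return_name`:

```python
from typing import Optional

import torch.nn as nn


def resolve_submodule(
    root: nn.Module, table: dict, model_type: str, fail_on_not_found: bool = True
) -> Optional[nn.Module]:
    """Pick the per-model-type name (falling back to "default") and walk the dotted path."""
    name = table.get(model_type, table["default"])
    if not name:  # e.g., OPT maps MLP to "" because it has no wrapping MLP module
        return root
    module = root
    for part in name.split("."):
        if not hasattr(module, part):
            if fail_on_not_found:
                raise AttributeError(f"'{name}' not found under {type(root).__name__}")
            # hybrid layers (e.g., an LFM2 conv block with no self_attn) resolve to None
            return None
        module = getattr(module, part)
    return module
```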
olive/passes/onnx/model_builder.py (8 changes: 7 additions & 1 deletion)
@@ -403,6 +403,7 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
module_map = {
"model.embed_tokens": self.embedding,
"model.norm": self.final_norm,
"model.embedding_norm": self.final_norm, # LFM2 uses embedding_norm instead of norm
"lm_head": self.lm_head,
**{f"model.layers.{i}": layer for i, layer in enumerate(self.layers)},
}
@@ -422,8 +423,13 @@ def set_tensor(module, tensor_name, tensor_value, local_bits, local_group_size):
for sub_name in tensor_name.split(".")[:-1]:
if sub_name.isdigit():
submodule = submodule[int(sub_name)]
else:
elif hasattr(submodule, sub_name):
submodule = getattr(submodule, sub_name)
else:
# Create missing submodule for hybrid architectures (e.g., LFM2 conv layers)
child = QuantizedTensorModule()
setattr(submodule, sub_name, child)
submodule = child
if isinstance(submodule, QuantizedTensorModule):
for q_attr, q_value in [("bits", local_bits), ("_group_size", local_group_size)]:
setattr(submodule, q_attr, q_value)
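The new `elif`/`else` branches let the builder materialize quantized-tensor leaves for weights whose parent module is not in the pre-built map (e.g., LFM2's conv blocks). A standalone sketch of that name-walking behaviour, with `SimpleNamespace` and `QTensor` as stand-ins for the builder's module map and `QuantizedTensorModule`:

```python
from types import SimpleNamespace


class QTensor:
    """Stand-in for QuantizedTensorModule: just holds quantization attributes."""

    def __init__(self):
        self.bits = None
        self._group_size = None


def walk(root, tensor_name):
    submodule = root
    for sub_name in tensor_name.split(".")[:-1]:
        if sub_name.isdigit():
            submodule = submodule[int(sub_name)]
        elif hasattr(submodule, sub_name):
            submodule = getattr(submodule, sub_name)
        else:
            # unknown parent (e.g., an LFM2 conv block): create the leaf on demand
            child = QTensor()
            setattr(submodule, sub_name, child)
            submodule = child
    return submodule


layer0 = SimpleNamespace()
root = SimpleNamespace(layers=[layer0])
leaf = walk(root, "layers.0.conv.in_proj.qweight")  # creates conv -> in_proj on the fly
assert isinstance(leaf, QTensor)
```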
olive/passes/pytorch/rotate.py (6 changes: 6 additions & 0 deletions)
@@ -166,6 +166,12 @@ def rotate_model(
@classmethod
def fuse_ln_linear(cls, layernorm: nn.Module, linear_layers: Iterable[nn.Linear]):
"""Fuse the linear operations in Layernorm into the adjacent linear blocks."""
# Hybrid models (e.g., LFM2) may have layers without attention, passing an empty list.
# Skip early to avoid resetting layernorm weights when there are no linears to fuse into.
linear_layers = list(linear_layers)
if not linear_layers:
return

for linear in linear_layers:
linear_dtype = linear.weight.dtype

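For context on why the early return matters: in rotation pipelines of this kind, the norm's learned scale is typically folded into the following linear layers and the norm is then reset to identity. A simplified sketch of that fusion (illustrative only, not Olive's exact implementation) makes the failure mode visible: without the guard, an empty `linear_layers` iterable would still wipe the norm weights even though nothing absorbed them.

```python
import torch
import torch.nn as nn


def fuse_ln_linear_sketch(norm: nn.Module, linears: list) -> None:
    """Absorb the (RMS)Norm scale into each following linear, then reset the norm."""
    linears = list(linears)
    if not linears:
        # hybrid layers without attention pass no linears; skip so the norm is untouched
        return
    for linear in linears:
        dtype = linear.weight.dtype
        w = linear.weight.data.to(torch.float64)
        # scale each input column of W by the norm weight it previously applied
        linear.weight.data = (w * norm.weight.to(torch.float64)).to(dtype)
    # the scale now lives in the linears, so the norm becomes a pure normalization
    norm.weight.data = torch.ones_like(norm.weight.data)
```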
test/common/test_hf_wrapper.py (68 changes: 68 additions & 0 deletions)
@@ -69,3 +69,71 @@ def test_hf_wrapper(model_path, tmp_path):
assert isinstance(module, nn.Linear)
for name in names:
assert name.startswith("self_attn.qkv_proj")


def test_hf_wrapper_lfm2():
"""Test LayerWrapper with LFM2 hybrid model (conv + attention layers)."""
from olive.model import HfModelHandler

input_model = HfModelHandler(model_path="tiny-random/lfm2")
model_wrapper = ModelWrapper(input_model.get_hf_model_config())

assert model_wrapper.model_type == "lfm2"

loaded_model = input_model.load_model()
model_wrapper.set_model(loaded_model)

# high-level submodules
assert isinstance(model_wrapper.get_embeds(False)[0], nn.Embedding)
assert isinstance(model_wrapper.get_lm_head(False), nn.Linear)
assert model_wrapper.get_pre_head_layernorm(False).__class__.__name__.endswith("RMSNorm")

layer_wrappers = model_wrapper.get_layer_wrappers()
assert len(layer_wrappers) == model_wrapper.num_hidden_layers

has_attn_layer = False
has_conv_layer = False

for layer_wrapper in layer_wrappers:
# all layers have layernorms and MLP
assert layer_wrapper.get_first_layer_norm(False).__class__.__name__.endswith("RMSNorm")
assert layer_wrapper.get_second_layer_norm(False).__class__.__name__.endswith("RMSNorm")

mlp_modules, mlp_names = layer_wrapper.get_mlp_inputs()
assert len(mlp_modules) == 2
for m in mlp_modules:
assert isinstance(m, nn.Linear)
for n in mlp_names:
assert n.startswith("feed_forward.")

mlp_out_modules, mlp_out_names = layer_wrapper.get_mlp_outputs()
assert len(mlp_out_modules) == 1
assert isinstance(mlp_out_modules[0], nn.Linear)
assert mlp_out_names[0].startswith("feed_forward.")

if layer_wrapper.attn is not None:
# attention layer
has_attn_layer = True
attn_modules, attn_names = layer_wrapper.get_attention_inputs()
assert len(attn_modules) == 3
for m in attn_modules:
assert isinstance(m, nn.Linear)

attn_out_modules, attn_out_names = layer_wrapper.get_attention_outputs()
assert len(attn_out_modules) == 1
assert isinstance(attn_out_modules[0], nn.Linear)
assert attn_out_names[0].startswith("self_attn.")
else:
# conv layer — attention methods return empty
has_conv_layer = True
attn_modules, attn_names = layer_wrapper.get_attention_inputs()
assert attn_modules == []
assert attn_names == []

attn_out_modules, attn_out_names = layer_wrapper.get_attention_outputs()
assert attn_out_modules == []
assert attn_out_names == []

# LFM2 must have both layer types
assert has_attn_layer, "Expected at least one attention layer"
assert has_conv_layer, "Expected at least one conv layer"