
Add output_shapes for AddMM #3262

Open

pHequals7 wants to merge 3 commits into ml-explore:main from pHequals7:fix/addmm-slice-customkernel-output-shapes

Conversation


@pHequals7 pHequals7 commented Mar 16, 2026

Summary

Adds output_shapes() to AddMM, enabling compile(shapeless=True) for models with biased Linear layers.

Change

AddMM::output_shapes returns inputs[0].shape() (the C matrix shape), which is already validated to match the output shape at construction in ops.cpp.

Context

Most transformer models use biased Linear layers, which dispatch through AddMM. Without this, compile(shapeless=True) throws "primitive does not have shape inference implemented". This follows the same pattern as #2601 (Convolution::output_shapes) and #1727 (shapeless SliceUpdate + Broadcast).

Discovered while porting mlx-whisper to Swift — whisper-small has 145 biased Linear layers that all fail in shapeless compile without this.

Files Changed

| File | Lines | Change |
| --- | --- | --- |
| mlx/primitives.h | +1 | Declaration |
| mlx/primitives.cpp | +6 | Implementation |

Previous scope

Originally included Slice and CustomKernel output_shapes — dropped per review feedback. CustomKernel shape inference via a callback function could be a separate discussion/PR.

Enable `compile(shapeless: true)` for models that use:

1. **AddMM** (biased Linear layers): Most transformer models use Linear
   with bias=true, which dispatches through AddMM. Without output_shapes,
   compile(shapeless:true) fails. The fix matches Matmul::output_shapes.

2. **Slice** (array subscripting): Any compiled function that slices
   arrays (e.g., `array[0..<N]`) needs Slice::output_shapes. The
   implementation re-normalizes slice bounds against runtime input shape.
   Limited to constant-dimension slices; variable-dimension slices
   should use take()/DynamicSlice.

3. **CustomKernel** (metalKernel API): Custom Metal kernels created via
   the metalKernel() API can now work inside compile(shapeless:true).
   Output shapes are stored at construction time and returned during
   compile-time shape inference. A -1 sentinel in output shapes triggers
   dynamic computation from input sizes (total_input_size / num_outputs),
   enabling kernels with variable output sizes (e.g., KV cache append).

Discovered while porting mlx-whisper to Swift using mlx-swift. All three
primitives are essential for compiled inference with custom fused kernels.
Comment on lines +348 to +355
std::vector<Shape> AddMM::output_shapes(const std::vector<array>& inputs) {
// out = alpha * (A @ B) + beta * C
// Output shape matches C (inputs[0]), with last dim from B (inputs[2])
auto out_shape = inputs[0].shape();
out_shape.back() = inputs[2].shape(-1);
return {std::move(out_shape)};
}

Member

Why wouldn't the out shape just be the c shape?

From ops.cpp

  if (c.shape() != out_shape) {
    throw std::invalid_argument(
        "[addmm] input c must broadcast to the output shape");
  }

  auto out = array(
      std::move(out_shape),
      out_type,
      std::make_shared<AddMM>(to_stream(s), alpha, beta),
      {a, b, c});

Author

Good catch — you're right, since c.shape() is already validated against out_shape at construction, we can just return it directly. Simplified in 95382f9.

C (inputs[0]) is already validated to match the output shape at
construction in ops.cpp, so we can return its shape directly
instead of recalculating from B's last dimension.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
if (resolved_shapes[i][j] == -1) resolved_shapes[i][j] = per_out;
}
}
}
Member

I think this is not doing what you were thinking it was doing.

resolved_shapes.size() is the number of outputs. total is the sum of number of elements of all inputs. What is per_out supposed to represent?

Say for instance I write a custom kernel that adds two arrays. per_out would be 2 times the input size 🤷‍♂️

The only way to do this properly is to pass a function that computes the output shapes from the input shapes. If this function is passed then shapeless compilation of the custom kernel will be automatically enabled otherwise not.

Author

yeah you're right, the -1 sentinel was a hack that only worked for my specific kv cache concat case. a shape inference function passed to metalKernel() makes way more sense as a general api.

stripped this pr down to just AddMM which is the straightforward one. happy to open a separate issue for the CustomKernel shape inference function if that's useful — or leave it for someone with better context on the compile internals.

Comment on lines +4786 to +4787
// Works for constant-dimension slices; variable-dimension slices
// should use take()/DynamicSlice instead.
Member

Suggested change
// Works for constant-dimension slices; variable-dimension slices
// should use take()/DynamicSlice instead.

Author

done, dropped it. makes sense that constant slices don't need this.

Drop Slice and CustomKernel changes per review feedback:
- Slice::output_shapes unnecessary for constant-dimension slices
- CustomKernel -1 sentinel is not a general solution; proper
  approach is a shape inference function (separate discussion)

Keeping only AddMM::output_shapes which is straightforward —
C's shape is already validated to match the output at construction.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@pHequals7 pHequals7 changed the title Add output_shapes for AddMM, Slice, and CustomKernel Add output_shapes for AddMM Mar 19, 2026
