Conversation
Enable `compile(shapeless: true)` for models that use:

1. **AddMM** (biased Linear layers): most transformer models use Linear with `bias=true`, which dispatches AddMM. Without `output_shapes`, `compile(shapeless: true)` fails. The fix matches `Matmul::output_shapes`.
2. **Slice** (array subscripting): any compiled function that slices arrays (e.g., `array[0..<N]`) needs `Slice::output_shapes`. The implementation re-normalizes slice bounds against the runtime input shape. Limited to constant-dimension slices; variable-dimension slices should use `take()`/DynamicSlice.
3. **CustomKernel** (`metalKernel` API): custom Metal kernels created via the `metalKernel()` API can now work inside `compile(shapeless: true)`. Output shapes are stored at construction time and returned during compile-time shape inference. A `-1` sentinel in output shapes triggers dynamic computation from input sizes (`total_input_size / num_outputs`), enabling kernels with variable output sizes (e.g., KV cache append).

Discovered while porting mlx-whisper to Swift using mlx-swift. All three primitives are essential for compiled inference with custom fused kernels.
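The slice-bound re-normalization described in point 2 can be sketched in plain Python (a hypothetical `normalize_bounds` helper for illustration; not the actual MLX implementation): bounds recorded at trace time are clamped against the runtime dimension size, Python-slicing style.

```python
def normalize_bounds(start, stop, dim_size):
    """Clamp recorded slice bounds to a runtime dimension size.

    Hypothetical illustration of re-normalizing constant slice bounds
    against a runtime shape; not the actual MLX implementation.
    """
    if start < 0:
        start += dim_size
    if stop < 0:
        stop += dim_size
    start = max(0, min(start, dim_size))
    stop = max(0, min(stop, dim_size))
    return start, max(start, stop)


def slice_len(start, stop, dim_size):
    # The sliced output length along this dimension.
    lo, hi = normalize_bounds(start, stop, dim_size)
    return hi - lo
```

With these semantics, `array[0..<N]` keeps a well-defined output length even when the runtime dimension is smaller than the traced one.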
```cpp
std::vector<Shape> AddMM::output_shapes(const std::vector<array>& inputs) {
  // out = alpha * (A @ B) + beta * C
  // Output shape matches C (inputs[0]), with last dim from B (inputs[2])
  auto out_shape = inputs[0].shape();
  out_shape.back() = inputs[2].shape(-1);
  return {std::move(out_shape)};
}
```
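One way to state the intended rule in plain Python (illustrative only, not MLX code): the output of `A @ B` keeps A's leading dims and takes B's last dim, and C is required to match that shape.

```python
def addmm_output_shape(a_shape, b_shape, c_shape):
    # out = alpha * (A @ B) + beta * C
    # The matmul A @ B has A's leading dims with B's last dim;
    # C is validated at construction to match this shape.
    out = list(a_shape)
    out[-1] = b_shape[-1]
    return out
```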
Why wouldn't the out shape just be the c shape? From ops.cpp:
```cpp
if (c.shape() != out_shape) {
  throw std::invalid_argument(
      "[addmm] input c must broadcast to the output shape");
}
auto out = array(
    std::move(out_shape),
    out_type,
    std::make_shared<AddMM>(to_stream(s), alpha, beta),
    {a, b, c});
```
Good catch — you're right, since c.shape() is already validated against out_shape at construction, we can just return it directly. Simplified in 95382f9.
C (inputs[0]) is already validated to match the output shape at construction in ops.cpp, so we can return its shape directly instead of recalculating from B's last dimension.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
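To see why returning C's shape directly is safe, here is an illustrative Python sketch (not the actual ops.cpp code) of the construction-time check together with the simplified shape inference:

```python
def addmm_shapes(a_shape, b_shape, c_shape):
    """Sketch of the construction-time validation for addmm.

    Illustration only, not MLX code: any C whose shape differs from
    the A @ B result is rejected, so shape inference can just return
    C's shape afterwards.
    """
    matmul_shape = list(a_shape[:-1]) + [b_shape[-1]]
    if list(c_shape) != matmul_shape:
        raise ValueError("[addmm] input c must broadcast to the output shape")
    # After validation, C's shape and the computed shape are identical.
    return list(c_shape)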
mlx/backend/metal/custom_kernel.cpp
Outdated
```cpp
if (resolved_shapes[i][j] == -1) resolved_shapes[i][j] = per_out;
      }
    }
  }
```
I think this is not doing what you were thinking it was doing.

`resolved_shapes.size()` is the number of outputs. `total` is the sum of the number of elements of all inputs. What is `per_out` supposed to represent?

Say for instance I write a custom kernel that adds two arrays: `per_out` would be 2 times the input size 🤷‍♂️

The only way to do this properly is to pass a function that computes the output shapes from the input shapes. If this function is passed then shapeless compilation of the custom kernel will be automatically enabled, otherwise not.
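The reviewer's counterexample is easy to check numerically. A standalone sketch of the heuristic (not the actual custom_kernel.cpp code):

```python
from math import prod

def per_out_heuristic(input_shapes, num_outputs):
    # The -1 sentinel resolution from the PR: divide the total number
    # of input elements evenly across the outputs.
    total = sum(prod(s) for s in input_shapes)
    return total // num_outputs

# Elementwise add of two length-8 arrays: the single output has
# 8 elements, but the heuristic reports 16.
result = per_out_heuristic([(8,), (8,)], 1)
```

The heuristic only happens to be right when the total input size equals the total output size, as in the author's KV-cache concat case.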
yeah you're right, the -1 sentinel was a hack that only worked for my specific kv cache concat case. a shape inference function passed to metalKernel() makes way more sense as a general api.
stripped this pr down to just AddMM which is the straightforward one. happy to open a separate issue for the CustomKernel shape inference function if that's useful — or leave it for someone with better context on the compile internals.
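A minimal sketch of what that proposal could look like (hypothetical `make_kernel` helper and parameter names, not MLX's actual `metalKernel` API): the kernel carries an optional shape-inference callback, and shapeless compilation is enabled only when the callback is present.

```python
def make_kernel(name, source, output_shapes=None):
    # output_shapes: optional callback mapping input shapes -> output shapes.
    # Hypothetical API sketch; names are invented for illustration.
    return {
        "name": name,
        "source": source,
        "output_shapes": output_shapes,
        "shapeless_ok": output_shapes is not None,
    }

# A concat-style kernel can describe its variable output shape exactly:
concat = make_kernel(
    "kv_append",
    "<metal source>",
    output_shapes=lambda shapes: [[shapes[0][0] + shapes[1][0]]],
)

# A kernel without the callback simply opts out of shapeless compile:
elementwise = make_kernel("add", "<metal source>")
```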
mlx/primitives.cpp
Outdated
```cpp
// Works for constant-dimension slices; variable-dimension slices
// should use take()/DynamicSlice instead.
```
Suggested change (delete these lines):

```diff
- // Works for constant-dimension slices; variable-dimension slices
- // should use take()/DynamicSlice instead.
```
done, dropped it. makes sense that constant slices don't need this.
Drop Slice and CustomKernel changes per review feedback:

- Slice::output_shapes unnecessary for constant-dimension slices
- CustomKernel -1 sentinel is not a general solution; proper approach is a shape inference function (separate discussion)

Keeping only AddMM::output_shapes, which is straightforward — C's shape is already validated to match the output at construction.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
**Summary**

Adds `output_shapes()` to `AddMM`, enabling `compile(shapeless=True)` for models with biased Linear layers.

**Change**

`AddMM::output_shapes` returns `inputs[0].shape()` (the C matrix shape), which is already validated to match the output shape at construction in `ops.cpp`.

**Context**

Most transformer models use biased Linear layers, which dispatch through AddMM. Without this, `compile(shapeless=True)` throws "primitive does not have shape inference implemented". This follows the same pattern as #2601 (`Convolution::output_shapes`) and #1727 (shapeless SliceUpdate + Broadcast).

Discovered while porting mlx-whisper to Swift — whisper-small has 145 biased Linear layers that all fail in shapeless compile without this.
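The practical payoff can be sketched with a toy compile cache (illustrative only; MLX's real compile machinery is far more involved): a shape-keyed cache recompiles for every new input shape, while a shapeless cache compiles once.

```python
def make_compiler(shapeless=False):
    # Toy model of a compile cache, for illustration only.
    cache = {}
    stats = {"compiles": 0}

    def compiled(input_shape):
        key = None if shapeless else tuple(input_shape)
        if key not in cache:
            stats["compiles"] += 1  # simulate an expensive trace/compile
            cache[key] = f"kernel<{key}>"
        return cache[key]

    return compiled, stats

shaped_fn, shaped_stats = make_compiler(shapeless=False)
shapeless_fn, shapeless_stats = make_compiler(shapeless=True)

# e.g. growing sequence lengths during autoregressive decoding:
for n in (100, 200, 300):
    shaped_fn((1, n, 512))
    shapeless_fn((1, n, 512))
```

Without shape inference on every primitive in the graph, the shapeless path is unavailable and each new sequence length pays the compile cost again.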
**Files Changed**

- `mlx/primitives.h`
- `mlx/primitives.cpp`

**Previous scope**

Originally included Slice and CustomKernel `output_shapes` — dropped per review feedback. CustomKernel shape inference via a callback function could be a separate discussion/PR.