Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
90619c5
Reapply "Attention bug fixes, tokamax splash defaulting logic (#282)"…
eltsai Dec 15, 2025
d848983
Reapply "Cross self attention switch (#251)" (#288)
eltsai Dec 15, 2025
c29fdc4
Disable unsafe rng
eltsai Dec 15, 2025
f68c7b0
Integrate tokamax ring attention as optional attention kernel for WAN…
eltsai Dec 17, 2025
8a18686
Merge branch 'main' into elisatsai_disable_unsafe_rng
eltsai Dec 29, 2025
a7fa4f0
Fixed formatting issue
eltsai Dec 30, 2025
41d9353
Updated scheduler test values
eltsai Dec 30, 2025
d128e32
Updated values based on v5p-8 tests
eltsai Dec 30, 2025
70ce989
Fixing ring attention
eltsai Jan 5, 2026
ed47e5f
moving kernel init outside the sharding map
eltsai Feb 10, 2026
65e7f93
Revert "moving kernel init outside the sharding map"
eltsai Feb 15, 2026
a0c377f
jitting and sharding vae, refactored for loops in jitted VAE, 132 sec…
eltsai Feb 23, 2026
e7cd3c4
Renaming VAE sharding axis to vae_spatial
eltsai Feb 26, 2026
c236d56
Renaming VAE sharding axis to vae_spatial
eltsai Feb 26, 2026
9bcd458
ring-attention
coolkp Mar 2, 2026
0e60bbb
Merge remote-tracking branch 'origin/kunjanp-ring-attention' into eli…
eltsai Mar 4, 2026
10f2f33
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
eltsai Mar 4, 2026
ffd7933
fixing attention from merging main
eltsai Mar 5, 2026
62e3b06
Fix attention_flax API regression from manual edits regarding context…
eltsai Mar 5, 2026
0a7d593
Merge branch 'elisatsai_ring_attention' of https://github.com/AI-Hype…
eltsai Mar 5, 2026
115fffa
Added sharding on ROPE
eltsai Mar 10, 2026
e04e78d
cfg cache
Mar 9, 2026
5b91824
Merged CFG cache, 220 sec using tokamax_flash
eltsai Mar 11, 2026
2d4eae1
Changed profiling logic
eltsai Mar 12, 2026
438fefd
Format fix
eltsai Mar 16, 2026
dff5c30
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
eltsai Mar 16, 2026
7293017
updated vae config logic to be consistent, updated xprof logic
eltsai Mar 19, 2026
b193301
feat: sync pyink, add splash_attention __init__, and exclude kernel t…
eltsai Mar 30, 2026
5823603
Merge origin/main into elisatsai_ring_attention
eltsai Mar 30, 2026
7375d6e
fix: reformat attention_ltx2.py jnp.clip lines to pass pyink formatter
eltsai Mar 30, 2026
768416a
Fix pyink error
eltsai Mar 30, 2026
0fa8678
fixing kernel precision
eltsai Apr 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
- name: PyTest
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --ignore=src/maxdiffusion/kernels/splash_attention -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
7 changes: 3 additions & 4 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand All @@ -60,7 +61,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'tokamax_flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
Expand All @@ -81,9 +82,7 @@ flash_block_sizes: {
"block_q_dkv" : 512,
"block_kv_dkv" : 512,
"block_kv_dkv_compute" : 512,
"block_q_dq" : 512,
"block_kv_dq" : 512,
"use_fused_bwd_kernel": False,
"use_fused_bwd_kernel": True
}
# Use on v6e
# flash_block_sizes: {
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ dataset_name: ''
train_split: 'train'
dataset_type: 'tfrecord'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1.0
per_device_batch_size: 0.125
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
Expand Down
Empty file.
15 changes: 15 additions & 0 deletions src/maxdiffusion/kernels/splash_attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Splash Attention kernels."""
268 changes: 268 additions & 0 deletions src/maxdiffusion/kernels/splash_attention/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base functionality for Sparse Flash Attention."""

import functools
from typing import Final, NamedTuple, TypeAlias
import jax
import jax.numpy as jnp
import numpy as np
from . import splash_attention_mask_info as mask_info_lib


# Re-exported so callers of this module can reference the mask-info type
# without importing the mask-info module directly.
MaskInfo = mask_info_lib.MaskInfo


# Large negative constant substituted for masked-out attention logits.
# Presumably a fraction (0.7) of the float32 max rather than -inf so that
# arithmetic on masked logits cannot overflow to NaN — TODO confirm rationale.
DEFAULT_MASK_VALUE: Final[float] = -0.7 * float(np.finfo(np.dtype("float32")).max)


class SegmentIds(NamedTuple):
  """SegmentIds for Q and KV sequences.

  SegmentIds are a mechanism to ensure that there is no cross-attention
  between segments (fractions of a sequence) that have been concatenated
  together into one sequence. Each array is a list of integer ids. Only
  tokens with the same id are allowed to attend to each other.

  The static mask (e.g. causal) is "and-ed" with the segment id mask to form
  the actual attention mask. It is important that the latter does not have
  any all-zero rows (along dimension kv). Otherwise it would result in an
  invalid softmax (the denominator would be 0).
  This condition holds for causal self-attention because in this case segment
  ids form a block diagonal matrix so at least one element in each row is set.
  It is easy to break this condition with non-self-attention configurations.

  Attributes:
    q: segment ids along the Q sequence
    kv: segment ids along the KV sequence
  """

  # PartitionSpec is accepted so the same structure can also carry sharding
  # annotations instead of concrete arrays.
  q: jax.Array | jax.sharding.PartitionSpec  # [q_seq_len]
  kv: jax.Array | jax.sharding.PartitionSpec  # [kv_seq_len]


# Return type of SplashAttention function that implements the custom vjp rule:
# either just the output array, or (output, residuals-dict) when residuals
# such as logsumexp are requested.
SplashCustomReturnType: TypeAlias = jax.Array | tuple[jax.Array, dict[str, jax.Array]]

# Residuals saved by the forward pass for consumption by the backward pass.
SplashResidualsType = tuple[
    jax.Array,  # q
    jax.Array,  # k
    jax.Array,  # v
    SegmentIds | None,  # segment_ids
    jax.Array | None,  # sinks
    jax.Array,  # out
    jax.Array,  # logsumexp
    MaskInfo | None,  # dkv_mask_info
]


def _attention_reference_impl(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    mask: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: bool,
    attn_logits_soft_cap: float | None,
) -> SplashCustomReturnType:
  """Single-head reference attention computed in float32.

  Args:
    q: [q_seq_len, head_dim] queries.
    k: [kv_seq_len, head_dim] keys.
    v: [kv_seq_len, head_dim] values.
    mask: boolean [q_seq_len, kv_seq_len] attention mask.
    segment_ids: optional per-token segment ids; tokens may only attend
      within their own segment.
    sinks: optional scalar sink logit that joins the softmax normalization
      but contributes nothing to the output.
    mask_value: value substituted for masked-out logits.
    save_residuals: when True, also return logsumexp and max_logits.
    attn_logits_soft_cap: when set, logits become cap * tanh(logits / cap).

  Returns:
    [q_seq_len, head_dim] output, or (output, residuals) if save_residuals.
  """
  scores = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32))

  if segment_ids is not None:
    same_segment = segment_ids.q[:, None] == segment_ids.kv[None, :]
    mask = jnp.logical_and(mask, same_segment)

  if attn_logits_soft_cap is not None:
    scores = jnp.tanh(scores / attn_logits_soft_cap) * attn_logits_soft_cap

  if sinks is not None:
    assert sinks.shape == ()  # should already be vmapped
    sinks = sinks.astype(scores.dtype)

  scores = jnp.where(mask, scores, mask_value)
  row_max = scores.max(axis=-1)
  if sinks is not None:
    row_max = jnp.maximum(row_max, sinks)

  unnormalized = jnp.exp(scores - row_max[..., None])
  denominator = unnormalized.sum(axis=-1)
  if sinks is not None:
    denominator = denominator + jnp.exp(sinks - row_max)
  probs = unnormalized / denominator[..., None]

  out = jnp.einsum("st,td->sd", probs, v.astype(jnp.float32))

  if save_residuals:
    logsumexp = row_max + jnp.log(denominator)
    return out, {"logsumexp": logsumexp, "max_logits": row_max}
  return out


def _attention_reference_custom_bwd(
    do,
    q,
    k,
    v,
    mask,
    segment_ids,
    sinks,
    o,
    logsumexp,
    mask_value: float = DEFAULT_MASK_VALUE,
    backward_impl: str = "vanilla",
    attn_logits_soft_cap: float | None = None,
) -> tuple[jax.Array, jax.Array, jax.Array, None, None, jax.Array | None]:
  """Single-head backward pass of the reference attention.

  Recomputes the softmax probabilities from the saved ``logsumexp`` residual
  and propagates the output cotangent ``do`` into dq, dk, dv (and dsinks when
  sinks are present). The two ``None`` entries in the return tuple are
  placeholders for the non-differentiable mask and segment-id inputs.
  """
  uncapped_logits = jnp.einsum("qc,kc->qk", q, k, preferred_element_type=jnp.float32)

  if attn_logits_soft_cap is not None:
    # Soft cap: cap * tanh(logits / cap). The uncapped logits are kept so the
    # tanh derivative can be applied to ds further below.
    logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
    logits = logits * attn_logits_soft_cap
  else:
    logits = uncapped_logits

  if segment_ids is not None:
    mask = jnp.logical_and(mask, segment_ids.q[:, None] == segment_ids.kv[None, :])
  logits = jnp.where(mask, logits, mask_value)

  # Recompute p = softmax(logits) as exp(logits - logsumexp).
  p = jnp.exp(logits - logsumexp[..., None])
  do = do.astype(jnp.float32)  # pytype: disable=attribute-error
  dv = jnp.einsum("pt,pd->td", p, do).astype(v.dtype)
  dp = jnp.einsum("pd,td->pt", do, v.astype(jnp.float32))

  # These two ways of computing ds are mathematically equivalent. The first
  # involves reducing over the head_dim dimension and the second involves
  # reducing over a sequence dimension. They tend to produce slightly different
  # numerics.
  if backward_impl == "flash":
    di = jnp.sum(o.astype(jnp.float32) * do, axis=-1)[..., None]
  else:
    di = jnp.einsum("st,st->s", dp, p)[:, None]
  ds = (dp - di) * p
  if attn_logits_soft_cap is not None:
    # Chain rule through the soft cap: d/dx[cap * tanh(x / cap)] equals
    # 1 - tanh^2(x / cap); written here as g + g*d = ds * (1 - d) * (1 + d).
    normalized = uncapped_logits / attn_logits_soft_cap
    d = jnp.tanh(normalized)
    g = ds * (1 - d)
    ds = g + g * d
  dk = jnp.einsum("sd,st->td", q.astype(jnp.float32), ds).astype(k.dtype)
  dq = jnp.einsum("st,td->sd", ds, k.astype(jnp.float32)).astype(q.dtype)
  dsinks = None
  if sinks is not None:
    # The sink logit only affects the softmax denominator, hence its gradient
    # is a reduction over the output times the output cotangent.
    sinks_exp = -jnp.exp(sinks[..., None, None].astype(jnp.float32) - logsumexp[..., None].astype(jnp.float32))
    dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2))
  return dq, dk, dv, None, None, dsinks


@functools.partial(
    jax.jit,
    static_argnames=[
        "mask_value",
        "save_residuals",
        "attn_logits_soft_cap",
        "is_mqa",
    ],
)
def attention_reference(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    mask: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    mask_value: float = DEFAULT_MASK_VALUE,
    save_residuals: bool = False,
    attn_logits_soft_cap: float | None = None,
):
  """A JIT-compiled reference implementation of attention, handles MQA and MHA.

  Maps the single-head reference over the leading head axis. In MQA the one
  KV head is shared by every Q head; in grouped (GQA) attention the KV heads
  are repeated so consecutive groups of Q heads read the same KV head.
  """
  per_head = functools.partial(
      _attention_reference_impl,
      mask_value=mask_value,
      save_residuals=save_residuals,
      attn_logits_soft_cap=attn_logits_soft_cap,
  )

  if is_mqa:
    # Single shared KV head: only q (and sinks) carry a head axis.
    mapped = jax.vmap(per_head, in_axes=(0, None, None, None, None, 0))
  else:
    # In grouped attention (1 < num_kv_heads && num_kv_heads < num_q_heads)
    # we interleave the KV heads across the Q heads. For example, with
    # 8 Q heads and 4 KV heads:
    #   Q heads [0, 1] see KV head 0, Q heads [2, 3] see KV head 1,
    #   Q heads [4, 5] see KV head 2, Q heads [6, 7] see KV head 3.
    num_kv_heads, num_q_heads = k.shape[0], q.shape[0]
    assert num_q_heads % num_kv_heads == 0

    if num_kv_heads < num_q_heads:
      # Repeat K and V heads so each Q head has a matching KV head.
      group_size = num_q_heads // num_kv_heads
      k = jnp.repeat(k, repeats=group_size, axis=0)
      v = jnp.repeat(v, repeats=group_size, axis=0)

    mapped = jax.vmap(per_head, in_axes=(0, 0, 0, None, None, 0))

  return mapped(q, k, v, mask, segment_ids, sinks)


@functools.partial(jax.jit, static_argnames=["is_mqa", "backward_impl", "attn_logits_soft_cap"])
def attention_reference_vjp(
    do,
    q,
    k,
    v,
    mask,
    segment_ids,
    sinks,
    o,
    logsumexp,
    *,
    is_mqa: bool,
    backward_impl: str = "vanilla",
    attn_logits_soft_cap: float | None = None,
):
  """Wrapper for backward reference that handles GQA/MQA broadcasting and reduction.

  Vmaps the single-head backward pass over the head axis, repeating KV heads
  for grouped attention on the way in and summing the per-Q-head KV gradients
  back down to the true number of KV heads on the way out.
  """
  per_head_bwd = functools.partial(
      _attention_reference_custom_bwd,
      backward_impl=backward_impl,
      attn_logits_soft_cap=attn_logits_soft_cap,
  )

  num_q_heads = q.shape[0]
  num_kv_heads = 1 if is_mqa else k.shape[0]
  assert num_q_heads % num_kv_heads == 0
  group_size = num_q_heads // num_kv_heads
  is_grouped = (not is_mqa) and num_kv_heads < num_q_heads

  if is_mqa:
    mapped = jax.vmap(per_head_bwd, in_axes=(0, 0, None, None, None, None, 0, 0, 0))
  else:
    mapped = jax.vmap(per_head_bwd, in_axes=(0, 0, 0, 0, None, None, 0, 0, 0))
    if is_grouped:
      # Interleave the KV heads to match the corresponding Q heads.
      k = jnp.repeat(k, group_size, axis=0)
      v = jnp.repeat(v, group_size, axis=0)

  dq, dk, dv, _, _, dsinks = mapped(do, q, k, v, mask, segment_ids, sinks, o, logsumexp)

  if is_mqa:
    # Every Q head shared the single KV head: reduce over the Q-head axis.
    dk, dv = dk.sum(axis=0), dv.sum(axis=0)
  elif is_grouped:
    # Reduce only across the group dimension so the output keeps KV heads.
    dk = dk.reshape(num_kv_heads, group_size, *dk.shape[1:]).sum(axis=1)
    dv = dv.reshape(num_kv_heads, group_size, *dv.shape[1:]).sum(axis=1)

  return dq, dk, dv, dsinks
Binary file not shown.
Loading
Loading