From b7d47860f224338eacaa8ca1e7d6b1a702dbb016 Mon Sep 17 00:00:00 2001
From: Soumya Snigdha Kundu
Date: Mon, 13 Apr 2026 11:57:21 +0100
Subject: [PATCH 1/2] fix(attention): allow `rel_pos_embedding` with
 `use_flash_attention` (#7997)

Lift the hard `ValueError` that prevented combining `rel_pos_embedding`
with `use_flash_attention=True` in `SABlock` and `CrossAttentionBlock`.

When a relative-position bias (and/or causal mask) is present, build an
additive attention bias and pass it via `attn_mask` to
`torch.nn.functional.scaled_dot_product_attention`. With a null bias the
no-mask fast path is preserved, so PyTorch can still dispatch the true
flash kernel; otherwise SDPA falls back to the memory-efficient or cuDNN
backend, both of which accept an additive float bias with working
gradients.
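For review context, here is a minimal standalone sketch of the SDPA
contract this change relies on (plain PyTorch, not MONAI code; the
shapes and the random bias standing in for the rel-pos term are
illustrative only):

    import torch
    import torch.nn.functional as F

    B, H, L, D = 2, 4, 16, 8
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

    # No mask: SDPA remains eligible for the true flash kernel.
    out_fast = F.scaled_dot_product_attention(q, k, v)

    # Additive float bias via attn_mask (stand-in for the output of
    # rel_pos_embedding): SDPA selects a backend that accepts it.
    bias = torch.randn(B, H, L, L)
    out_biased = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

    # The biased call matches explicit attention with the bias added
    # before the softmax, which is what the new equivalence tests check.
    att = torch.softmax(q @ k.transpose(-2, -1) / D**0.5 + bias, dim=-1)
    assert torch.allclose(att @ v, out_biased, atol=1e-5)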
Replace the `ValueError` unit tests with numerical-equivalence tests
against the explicit attention path for 2D and 3D `input_size`.
Docstrings for `use_flash_attention` are updated to clarify that backend
selection is delegated to SDPA.

Signed-off-by: Soumya Snigdha Kundu
---
 monai/networks/blocks/crossattention.py      | 35 ++++++++++++++---
 monai/networks/blocks/selfattention.py       | 40 ++++++++++++++++----
 tests/networks/blocks/test_crossattention.py | 32 +++++++++++-----
 tests/networks/blocks/test_selfattention.py  | 32 +++++++++++-----
 4 files changed, 106 insertions(+), 33 deletions(-)

diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
index baaa21ed1f..8c8081ca75 100644
--- a/monai/networks/blocks/crossattention.py
+++ b/monai/networks/blocks/crossattention.py
@@ -61,8 +61,11 @@ def __init__(
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the
                 relative positional parameter size.
             attention_dtype: cast attention operations to this dtype.
-            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
-                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            use_flash_attention: if True, dispatch attention through
+                ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
+                the true flash kernel is used only when no attention bias is present. When combined
+                with ``rel_pos_embedding`` or ``causal``, PyTorch will fall back to the
+                memory-efficient or cuDNN SDPA backend.
         """
 
         super().__init__()
@@ -88,9 +91,6 @@ def __init__(
                 "to True. save_attn can only be used if use_flash_attention is False"
             )
 
-        if use_flash_attention and rel_pos_embedding is not None:
-            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
-
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -155,8 +155,31 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
             k = k.to(self.attention_dtype)
 
         if self.use_flash_attention:
+            # Additive bias path mirrors SABlock: a null bias keeps the true
+            # flash kernel fast path (pure causal stays on it via is_causal);
+            # a rel_pos_embedding bias forces the efficient/cuDNN SDPA backend.
+            bias: torch.Tensor | None = None
+            lq, lk = q.shape[-2], k.shape[-2]
+
+            if self.rel_positional_embedding is not None:
+                zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
+                bias = self.rel_positional_embedding(x, zero_logits, q)
+
+            is_causal_arg = self.causal
+            if self.causal and bias is not None:
+                causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
+                causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
+                bias = bias + causal_bias
+                is_causal_arg = False
+
             x = torch.nn.functional.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=bias,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=is_causal_arg,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 2791d2fb00..c1fb385941 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -63,8 +63,11 @@ def __init__(
             attention_dtype: cast attention operations to this dtype.
             include_fc: whether to include the final linear layer. Default to True.
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
-            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
-                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            use_flash_attention: if True, dispatch attention through
+                ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
+                the true flash kernel is used only when no attention bias is present. When combined
+                with ``rel_pos_embedding``, ``causal``, or ``attn_mask``, PyTorch will fall back to
+                the memory-efficient or cuDNN SDPA backend.
 
         """
 
@@ -94,9 +97,6 @@ def __init__(
                 "to True. save_attn can only be used if use_flash_attention is False."
             )
 
-        if use_flash_attention and rel_pos_embedding is not None:
-            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
-
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.out_proj: nn.Linear | nn.Identity
@@ -174,14 +174,40 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
             k = k.to(self.attention_dtype)
 
         if self.use_flash_attention:
+            # Build an additive attention bias for rel_pos_embedding and/or a
+            # user attn_mask; the causal mask is merged into that bias only
+            # when one is present. A null bias preserves the no-mask fast path
+            # so PyTorch can still pick the true flash kernel when available.
+            bias: torch.Tensor | None = None
+            lq, lk = q.shape[-2], k.shape[-2]
+
+            if self.rel_positional_embedding is not None:
+                zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
+                bias = self.rel_positional_embedding(x, zero_logits, q)
+
+            is_causal_arg = self.causal
+            if self.causal and (bias is not None or attn_mask is not None):
+                causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
+                causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
+                bias = causal_bias if bias is None else bias + causal_bias
+                is_causal_arg = False
+
+            if attn_mask is not None:
+                if self.causal:
+                    raise ValueError("Causal attention does not support attention masks.")
+                mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
+                mask_bias.masked_fill_(attn_mask == 0, float("-inf"))
+                mask_bias = mask_bias.unsqueeze(1).unsqueeze(2)
+                bias = mask_bias if bias is None else bias + mask_bias
+
             x = F.scaled_dot_product_attention(
                 query=q,
                 key=k,
                 value=v,
-                attn_mask=attn_mask,
+                attn_mask=bias,
                 scale=self.scale,
                 dropout_p=self.dropout_rate,
-                is_causal=self.causal,
+                is_causal=is_causal_arg,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
diff --git a/tests/networks/blocks/test_crossattention.py b/tests/networks/blocks/test_crossattention.py
index f691f4e534..68bac807ee 100644
--- a/tests/networks/blocks/test_crossattention.py
+++ b/tests/networks/blocks/test_crossattention.py
@@ -30,7 +30,7 @@
     [
         {
             **{k: v for k, v in params.items() if k not in ["rel_pos_embedding_val"]},
-            "rel_pos_embedding": params["rel_pos_embedding_val"] if not params["use_flash_attention"] else None,
+            "rel_pos_embedding": params["rel_pos_embedding_val"],
         },
         (2, 512, params["hidden_size"]),
         (2, 512, params["hidden_size"]),
@@ -69,16 +69,28 @@ def test_save_attn_with_flash_attention(self):
             hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
         )
 
+    @skipUnless(has_einops, "Requires einops")
     def test_rel_pos_embedding_with_flash_attention(self):
-        with self.assertRaises(ValueError):
-            CrossAttentionBlock(
-                hidden_size=128,
-                num_heads=3,
-                dropout_rate=0.1,
-                use_flash_attention=True,
-                save_attn=False,
-                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
-            )
+        # rel_pos_embedding combined with use_flash_attention now dispatches
+        # via SDPA with an additive bias; it must match the explicit path.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            seq_len = int(np.prod(input_size))
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
 
     @skipUnless(has_einops, "Requires einops")
     def test_attention_dim_not_multiple_of_heads(self):
         with self.assertRaises(ValueError):
diff --git a/tests/networks/blocks/test_selfattention.py b/tests/networks/blocks/test_selfattention.py
index af52918612..85949498c5 100644
--- a/tests/networks/blocks/test_selfattention.py
+++ b/tests/networks/blocks/test_selfattention.py
@@ -43,7 +43,7 @@
                     "input_size": input_size,
                     "include_fc": include_fc,
                     "use_combined_linear": use_combined_linear,
-                    "use_flash_attention": True if rel_pos_embedding is None else False,
+                    "use_flash_attention": True,
                 },
                 (2, 512, hidden_size),
                 (2, 512, hidden_size),
@@ -67,16 +67,28 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
 
+    @skipUnless(has_einops, "Requires einops")
     def test_rel_pos_embedding_with_flash_attention(self):
-        with self.assertRaises(ValueError):
-            SABlock(
-                hidden_size=128,
-                num_heads=3,
-                dropout_rate=0.1,
-                use_flash_attention=True,
-                save_attn=False,
-                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
-            )
+        # rel_pos_embedding is now allowed with use_flash_attention; SDPA picks
+        # a fused backend that supports an additive attention bias. The two
+        # code paths must be numerically equivalent for the same weights.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, int(np.prod(input_size)), 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
 
     def test_save_attn_with_flash_attention(self):
         with self.assertRaises(ValueError):

From c5b2a1c34a566d41da1999e2b3f68e120689d4f9 Mon Sep 17 00:00:00 2001
From: Soumya Snigdha Kundu
Date: Mon, 4 May 2026 10:47:03 +0100
Subject: [PATCH 2/2] test(attention): cover merged causal- and attn_mask-bias
 flash branches

Address CodeRabbit review on PR #8842:

- Narrow the use_flash_attention docstring in SABlock and CrossAttentionBlock
  so it reflects the actual implementation: pure causal masking keeps the
  fast path via is_causal=True; only an additive bias (rel_pos_embedding, or
  causal/attn_mask merged with another bias) forces SDPA to fall back to the
  memory-efficient or cuDNN backend.
- Extend the numerical-equivalence tests to cover the new merged-bias paths:
  causal=True + rel_pos_embedding for both blocks, and attn_mask +
  rel_pos_embedding for SABlock. All cases assert
  assert_allclose(out_flash, out_ref, atol=1e-4) on 2D and 3D inputs.
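As a standalone illustration of the first point (plain PyTorch, not
MONAI code; shapes are arbitrary), is_causal=True and the same mask
expressed as an additive -inf bias through attn_mask agree numerically,
which is the equivalence the merged-bias branch relies on:

    import torch
    import torch.nn.functional as F

    B, H, L, D = 2, 4, 16, 8
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

    # Fast path: pure causal masking via the is_causal flag.
    out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Fallback path: the causal mask as an additive bias, the form the
    # merged rel-pos + causal bias takes inside the blocks.
    causal_bias = torch.zeros(L, L)
    causal_bias.masked_fill_(torch.tril(torch.ones(L, L, dtype=torch.bool)).logical_not(), float("-inf"))
    out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bias)

    assert torch.allclose(out_flag, out_bias, atol=1e-5)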
Signed-off-by: Soumya Snigdha Kundu
---
 monai/networks/blocks/crossattention.py      |  8 +--
 monai/networks/blocks/selfattention.py       |  8 +--
 tests/networks/blocks/test_crossattention.py | 25 ++++++++++
 tests/networks/blocks/test_selfattention.py  | 51 ++++++++++++++++++++
 4 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
index 8c8081ca75..ffcfd8b801 100644
--- a/monai/networks/blocks/crossattention.py
+++ b/monai/networks/blocks/crossattention.py
@@ -63,9 +63,11 @@ def __init__(
             attention_dtype: cast attention operations to this dtype.
             use_flash_attention: if True, dispatch attention through
                 ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
-                the true flash kernel is used only when no attention bias is present. When combined
-                with ``rel_pos_embedding`` or ``causal``, PyTorch will fall back to the
-                memory-efficient or cuDNN SDPA backend.
+                the true flash kernel is used when no custom additive attention bias is passed.
+                Pure ``causal`` masking (with no ``rel_pos_embedding``) keeps the fast path via
+                ``is_causal=True``. When an additive bias is required (for example,
+                ``rel_pos_embedding``, or ``causal`` merged with another bias), PyTorch falls
+                back to the memory-efficient or cuDNN SDPA backend.
         """
 
         super().__init__()
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index c1fb385941..b03b237ba0 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -65,9 +65,11 @@ def __init__(
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
             use_flash_attention: if True, dispatch attention through
                 ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
-                the true flash kernel is used only when no attention bias is present. When combined
-                with ``rel_pos_embedding``, ``causal``, or ``attn_mask``, PyTorch will fall back to
-                the memory-efficient or cuDNN SDPA backend.
+                the true flash kernel is used when no custom additive attention bias is passed.
+                Pure ``causal`` masking (with no ``rel_pos_embedding`` or ``attn_mask``) keeps the
+                fast path via ``is_causal=True``. When an additive bias is required (for example,
+                ``rel_pos_embedding``, or ``causal``/``attn_mask`` merged with another bias),
+                PyTorch falls back to the memory-efficient or cuDNN SDPA backend.
 
     """
diff --git a/tests/networks/blocks/test_crossattention.py b/tests/networks/blocks/test_crossattention.py
index 68bac807ee..ebd39b92b7 100644
--- a/tests/networks/blocks/test_crossattention.py
+++ b/tests/networks/blocks/test_crossattention.py
@@ -92,6 +92,31 @@ def test_rel_pos_embedding_with_flash_attention(self):
                 out_ref = block_ref(test_data)
             assert_allclose(out_flash, out_ref, atol=1e-4)
 
+    @skipUnless(has_einops, "Requires einops")
+    def test_causal_rel_pos_with_flash_attention(self):
+        # Exercise the merged causal-bias branch: causal=True together with
+        # rel_pos_embedding builds an additive bias and disables is_causal.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+                "causal": True,
+                "sequence_length": seq_len,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
     @skipUnless(has_einops, "Requires einops")
     def test_attention_dim_not_multiple_of_heads(self):
         with self.assertRaises(ValueError):
diff --git a/tests/networks/blocks/test_selfattention.py b/tests/networks/blocks/test_selfattention.py
index 85949498c5..1ff6398894 100644
--- a/tests/networks/blocks/test_selfattention.py
+++ b/tests/networks/blocks/test_selfattention.py
@@ -90,6 +90,57 @@ def test_rel_pos_embedding_with_flash_attention(self):
                 out_ref = block_ref(test_data)
             assert_allclose(out_flash, out_ref, atol=1e-4)
 
+    @skipUnless(has_einops, "Requires einops")
+    def test_causal_rel_pos_with_flash_attention(self):
+        # Exercise the merged causal-bias branch: causal=True together with
+        # rel_pos_embedding builds an additive bias and disables is_causal,
+        # so flash and reference paths must still match numerically.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+                "causal": True,
+                "sequence_length": seq_len,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_attn_mask_rel_pos_with_flash_attention(self):
+        # Exercise the user-attn-mask + rel_pos branch: the user mask is
+        # merged into the additive bias passed via SDPA's attn_mask argument.
+ for input_size in [(16, 32), (8, 8, 8)]: + seq_len = int(np.prod(input_size)) + input_param = { + "hidden_size": 128, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": RelPosEmbedding.DECOMPOSED, + "input_size": input_size, + } + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_flash = SABlock(**input_param, use_flash_attention=True).to(device) + block_ref = SABlock(**input_param, use_flash_attention=False).to(device) + block_ref.load_state_dict(block_flash.state_dict()) + test_data = torch.randn(2, seq_len, 128).to(device) + attn_mask = torch.ones(2, seq_len, dtype=torch.bool, device=device) + attn_mask[:, seq_len // 2 :] = False # mask out the second half + with eval_mode(block_flash), eval_mode(block_ref): + out_flash = block_flash(test_data, attn_mask=attn_mask) + out_ref = block_ref(test_data, attn_mask=attn_mask) + assert_allclose(out_flash, out_ref, atol=1e-4) + def test_save_attn_with_flash_attention(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True)