From 968a5aa87f4dc533c8ba1d77e0170ba39b4fa49a Mon Sep 17 00:00:00 2001 From: Nikhil Dev Goyal Date: Mon, 9 Mar 2026 03:25:28 -0700 Subject: [PATCH] Change to FastExpMinusOrZero PiperOrigin-RevId: 880761639 --- gemma/flash_attention.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 4f4336e8..61c5b326 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos, } float m = hn::ReduceMax(df, x); m = std::max(m, old_max); - x = hn::Exp(df, hn::Sub(x, hn::Set(df, m))); + x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m))); float scale = old_d * std::exp(old_max - m); old_d = hn::ReduceSum(df, x) + scale; old_max = m; @@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos, float m = hn::ReduceMax(df, x_max); m = std::max(m, old_max); VF m_vec = hn::Set(df, m); - x0 = hn::Exp(df, hn::Sub(x0, m_vec)); - x1 = hn::Exp(df, hn::Sub(x1, m_vec)); + x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec)); + x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec)); float scale = old_d * std::exp(old_max - m); VF x_sum = hn::Add(x0, x1); old_d = hn::ReduceSum(df, x_sum) + scale; @@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, [](auto a, auto b) HWY_ATTR { return hn::Add(a, b); }); } - VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max))); + VF4 scale = hn::Mul( + old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max))); old_d_vf = hn::Add(scale, x_sum); auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f)); const VF zero = hn::Zero(df); @@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( x_6_sum, x_7_sum, [](auto a, auto b) HWY_ATTR { return hn::Add(a, b); }); } - VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max))); + VF8 scale = hn::Mul( + old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max))); old_d_vf = hn::Add(scale, x_sum); auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f)); const VF zero = hn::Zero(df);