2121#include < limits>
2222
2323#include " compression/types.h" // GEMMA_DISABLED_TARGETS
24+ #include " gemma/flash_structs.h"
2425#include " util/threading_context.h"
2526#include " util/zones.h"
2627#ifndef HWY_DISABLED_TARGETS
@@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
444445// Sweeps a tile of accumulators (4 Q rows by NF K timesteps) from start_pos to
445446// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
446447// max_last_pos]. Returns the per-row state accumulated over both sweeps.
447- void TileFlashAttention4 (const MatPtrT<float >& q,
448- const uint32_t * HWY_RESTRICT q_offsets,
449- const MatPtrT<KV_t>& k, const size_t start_pos,
450- const uint32_t * HWY_RESTRICT last_pos,
451- const size_t min_last_pos, const size_t max_last_pos,
452- const MatPtrT<KV_t>& v, const size_t layer_idx,
453- const AttentionActivationsPtrs& activations,
454- MatPtrT<float >& att_out,
455- const uint32_t * HWY_RESTRICT out_offsets,
456- ThreadingContext& ctx, const size_t worker) {
448+ Tile4FlashState TileFlashAttention4 (
449+ const MatPtrT<float >& q, const uint32_t * HWY_RESTRICT q_offsets,
450+ const MatPtrT<KV_t>& k, const size_t start_pos,
451+ const uint32_t * HWY_RESTRICT last_pos, const size_t min_last_pos,
452+ const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
453+ const AttentionActivationsPtrs& activations, MatPtrT<float >& att_out,
454+ const uint32_t * HWY_RESTRICT out_offsets, ThreadingContext& ctx,
455+ const size_t worker) {
457456 GCPP_ZONE (ctx, worker, Zones::kFlashAttentionTileFlashAttention4 );
458457 using DF = hn::ScalableTag<float >;
459458 const DF df;
@@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT<float>& q,
467466 hwy::ZeroBytes (att_out.Row (0 ) + out_offsets[i],
468467 v.Cols () * sizeof (att_out.Row (0 )[0 ]));
469468 }
470- float old_m0 = -std::numeric_limits<float >::max () / 2 .0f ;
471- float old_m1 = -std::numeric_limits<float >::max () / 2 .0f ;
472- float old_m2 = -std::numeric_limits<float >::max () / 2 .0f ;
473- float old_m3 = -std::numeric_limits<float >::max () / 2 .0f ;
474- float old_d0 = 0 .0f ;
475- float old_d1 = 0 .0f ;
476- float old_d2 = 0 .0f ;
477- float old_d3 = 0 .0f ;
469+ Tile4FlashState state;
478470 size_t position = start_pos;
479471 while (position + kHTileSize - 1 <= min_last_pos) {
480472 int32_t k_offsets[kMaxNF ];
@@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
494486 x2 = hn::Mul (cap, hn::Tanh (df, hn::Mul (x2, one_over_cap)));
495487 x3 = hn::Mul (cap, hn::Tanh (df, hn::Mul (x3, one_over_cap)));
496488 }
497- scales[0 ] = SingleFlashAttentionRowVector (df, x0, old_m0, old_d0);
498- scales[1 ] = SingleFlashAttentionRowVector (df, x1, old_m1, old_d1);
499- scales[2 ] = SingleFlashAttentionRowVector (df, x2, old_m2, old_d2);
500- scales[3 ] = SingleFlashAttentionRowVector (df, x3, old_m3, old_d3);
489+ scales[0 ] = SingleFlashAttentionRowVector (df, x0, state.row_states [0 ].max ,
490+ state.row_states [0 ].d );
491+ scales[1 ] = SingleFlashAttentionRowVector (df, x1, state.row_states [1 ].max ,
492+ state.row_states [1 ].d );
493+ scales[2 ] = SingleFlashAttentionRowVector (df, x2, state.row_states [2 ].max ,
494+ state.row_states [2 ].d );
495+ scales[3 ] = SingleFlashAttentionRowVector (df, x3, state.row_states [3 ].max ,
496+ state.row_states [3 ].d );
501497 MulByConstAndAddTile4 (df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row (0 ),
502498 out_offsets, v.Cols ());
503499 position += kHTileSize ;
@@ -516,7 +512,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
516512 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
517513 float x0 =
518514 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
519- SingleFlashAttentionStep (x0, activations.config .att_cap , old_m0, old_d0,
515+ SingleFlashAttentionStep (x0, activations.config .att_cap ,
516+ state.row_states [0 ].max , state.row_states [0 ].d ,
520517 v.Row (k_pos), v.Cols (),
521518 att_out.Row (0 ) + out_offsets[0 ]);
522519 }
@@ -526,7 +523,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
526523 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
527524 float x1 =
528525 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
529- SingleFlashAttentionStep (x1, activations.config .att_cap , old_m1, old_d1,
526+ SingleFlashAttentionStep (x1, activations.config .att_cap ,
527+ state.row_states [1 ].max , state.row_states [1 ].d ,
530528 v.Row (k_pos), v.Cols (),
531529 att_out.Row (0 ) + out_offsets[1 ]);
532530 }
@@ -536,7 +534,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
536534 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
537535 float x2 =
538536 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
539- SingleFlashAttentionStep (x2, activations.config .att_cap , old_m2, old_d2,
537+ SingleFlashAttentionStep (x2, activations.config .att_cap ,
538+ state.row_states [2 ].max , state.row_states [2 ].d ,
540539 v.Row (k_pos), v.Cols (),
541540 att_out.Row (0 ) + out_offsets[2 ]);
542541 }
@@ -546,12 +545,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
546545 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
547546 float x3 =
548547 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
549- SingleFlashAttentionStep (x3, activations.config .att_cap , old_m3, old_d3,
548+ SingleFlashAttentionStep (x3, activations.config .att_cap ,
549+ state.row_states [3 ].max , state.row_states [3 ].d ,
550550 v.Row (k_pos), v.Cols (),
551551 att_out.Row (0 ) + out_offsets[3 ]);
552552 }
553553 ++position;
554554 }
555+ return state;
555556}
556557
557558// Rounds n to a number that can be used as the number of Q rows in a tile
0 commit comments