2121#include < limits>
2222
2323#include " compression/types.h" // GEMMA_DISABLED_TARGETS
24+ #include " gemma/flash_structs.h"
2425#include " util/threading_context.h"
2526#include " util/zones.h"
2627#ifndef HWY_DISABLED_TARGETS
@@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
444445// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
445446// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
446447// max_last_pos].
447- void TileFlashAttention4 (const MatPtrT<float >& q,
448- const uint32_t * HWY_RESTRICT q_offsets,
449- const MatPtrT<KV_t>& k, const size_t start_pos,
450- const uint32_t * HWY_RESTRICT last_pos,
451- const size_t min_last_pos, const size_t max_last_pos,
452- const MatPtrT<KV_t>& v, const size_t layer_idx,
453- const AttentionActivationsPtrs& activations,
454- MatPtrT<float >& att_out,
455- const uint32_t * HWY_RESTRICT out_offsets,
456- ThreadingContext& ctx, const size_t worker) {
448+ Tile4FlashParams TileFlashAttention4 (
449+ const MatPtrT<float >& q, const uint32_t * HWY_RESTRICT q_offsets,
450+ const MatPtrT<KV_t>& k, const size_t start_pos,
451+ const uint32_t * HWY_RESTRICT last_pos, const size_t min_last_pos,
452+ const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
453+ const AttentionActivationsPtrs& activations, MatPtrT<float >& att_out,
454+ const uint32_t * HWY_RESTRICT out_offsets, ThreadingContext& ctx,
455+ const size_t worker) {
457456 GCPP_ZONE (ctx, worker, Zones::kFlashAttentionTileFlashAttention4 );
458457 using DF = hn::ScalableTag<float >;
459458 const DF df;
@@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT<float>& q,
467466 hwy::ZeroBytes (att_out.Row (0 ) + out_offsets[i],
468467 v.Cols () * sizeof (att_out.Row (0 )[0 ]));
469468 }
470- float old_m0 = -std::numeric_limits<float >::max () / 2 .0f ;
471- float old_m1 = -std::numeric_limits<float >::max () / 2 .0f ;
472- float old_m2 = -std::numeric_limits<float >::max () / 2 .0f ;
473- float old_m3 = -std::numeric_limits<float >::max () / 2 .0f ;
474- float old_d0 = 0 .0f ;
475- float old_d1 = 0 .0f ;
476- float old_d2 = 0 .0f ;
477- float old_d3 = 0 .0f ;
469+ Tile4FlashParams params;
478470 size_t position = start_pos;
479471 while (position + kHTileSize - 1 <= min_last_pos) {
480472 int32_t k_offsets[kMaxNF ];
@@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
494486 x2 = hn::Mul (cap, hn::Tanh (df, hn::Mul (x2, one_over_cap)));
495487 x3 = hn::Mul (cap, hn::Tanh (df, hn::Mul (x3, one_over_cap)));
496488 }
497- scales[0 ] = SingleFlashAttentionRowVector (df, x0, old_m0, old_d0);
498- scales[1 ] = SingleFlashAttentionRowVector (df, x1, old_m1, old_d1);
499- scales[2 ] = SingleFlashAttentionRowVector (df, x2, old_m2, old_d2);
500- scales[3 ] = SingleFlashAttentionRowVector (df, x3, old_m3, old_d3);
489+ scales[0 ] = SingleFlashAttentionRowVector (df, x0, params.rff [0 ].max ,
490+ params.rff [0 ].d );
491+ scales[1 ] = SingleFlashAttentionRowVector (df, x1, params.rff [1 ].max ,
492+ params.rff [1 ].d );
493+ scales[2 ] = SingleFlashAttentionRowVector (df, x2, params.rff [2 ].max ,
494+ params.rff [2 ].d );
495+ scales[3 ] = SingleFlashAttentionRowVector (df, x3, params.rff [3 ].max ,
496+ params.rff [3 ].d );
501497 MulByConstAndAddTile4 (df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row (0 ),
502498 out_offsets, v.Cols ());
503499 position += kHTileSize ;
@@ -516,42 +512,43 @@ void TileFlashAttention4(const MatPtrT<float>& q,
516512 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
517513 float x0 =
518514 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
519- SingleFlashAttentionStep (x0, activations.config .att_cap , old_m0, old_d0,
520- v. Row (k_pos), v. Cols ( ),
521- att_out.Row (0 ) + out_offsets[0 ]);
515+ SingleFlashAttentionStep (x0, activations.config .att_cap ,
516+ params. rff [ 0 ]. max , params. rff [ 0 ]. d , v. Row (k_pos ),
517+ v. Cols (), att_out.Row (0 ) + out_offsets[0 ]);
522518 }
523519 if (position <= last_pos[1 ]) {
524520 // Past the last position, x1 doesn't count.
525521 CompressTraits<BF16>::Compress (df_compress, q.Row (0 ) + q_offsets[1 ],
526522 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
527523 float x1 =
528524 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
529- SingleFlashAttentionStep (x1, activations.config .att_cap , old_m1, old_d1,
530- v. Row (k_pos), v. Cols ( ),
531- att_out.Row (0 ) + out_offsets[1 ]);
525+ SingleFlashAttentionStep (x1, activations.config .att_cap ,
526+ params. rff [ 1 ]. max , params. rff [ 1 ]. d , v. Row (k_pos ),
527+ v. Cols (), att_out.Row (0 ) + out_offsets[1 ]);
532528 }
533529 if (position <= last_pos[2 ]) {
534530 // Past the last position, x2 doesn't count.
535531 CompressTraits<BF16>::Compress (df_compress, q.Row (0 ) + q_offsets[2 ],
536532 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
537533 float x2 =
538534 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
539- SingleFlashAttentionStep (x2, activations.config .att_cap , old_m2, old_d2,
540- v. Row (k_pos), v. Cols ( ),
541- att_out.Row (0 ) + out_offsets[2 ]);
535+ SingleFlashAttentionStep (x2, activations.config .att_cap ,
536+ params. rff [ 2 ]. max , params. rff [ 2 ]. d , v. Row (k_pos ),
537+ v. Cols (), att_out.Row (0 ) + out_offsets[2 ]);
542538 }
543539 if (position <= last_pos[3 ]) {
544540 // Past the last position, x3 doesn't count.
545541 CompressTraits<BF16>::Compress (df_compress, q.Row (0 ) + q_offsets[3 ],
546542 qkv_dim, tls, MakeSpan (q_bf, qkv_dim), 0 );
547543 float x3 =
548544 Dot (dbf, MakeConstSpan (q_bf, qkv_dim), 0 , k.Row (k_pos), qkv_dim);
549- SingleFlashAttentionStep (x3, activations.config .att_cap , old_m3, old_d3,
550- v. Row (k_pos), v. Cols ( ),
551- att_out.Row (0 ) + out_offsets[3 ]);
545+ SingleFlashAttentionStep (x3, activations.config .att_cap ,
546+ params. rff [ 3 ]. max , params. rff [ 3 ]. d , v. Row (k_pos ),
547+ v. Cols (), att_out.Row (0 ) + out_offsets[3 ]);
552548 }
553549 ++position;
554550 }
551+ return params;
555552}
556553
557554// Rounds n to a number that can be used as the number of Q rows in a tile
0 commit comments