@@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
 }
 
 // Handles a single v row of flash attention for a single q.k dot product.
-void HWY_INLINE SingleFlashAttentionStep(
-    float x, float cap, float& old_max, float& old_d,
-    const float* HWY_RESTRICT v, const size_t v_cols,
-    float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
+void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
+                                         float& old_d,
+                                         const float* HWY_RESTRICT v,
+                                         const size_t v_cols,
+                                         float* HWY_RESTRICT att_out) {
   if (cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     x = cap * std::tanh(x / cap);
@@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep(
   float one_over_d = 1.0f / old_d;
   scale *= one_over_d;
   x *= one_over_d;
-  MulByConst(scale, att_out, v_cols, p, worker);
-  MulByConstAndAdd(x, v, att_out, v_cols, p, worker);
+  MulByConst(scale, att_out, v_cols);
+  MulByConstAndAdd(x, v, att_out, v_cols);
 }
 
 // Calculates the complete attention outputs for a single row of q.
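
SingleFlashAttentionStep is the usual online-softmax update: correct the running max and denominator for the new logit, rescale the existing accumulator, then add the newly weighted v row. A minimal scalar sketch of that idea follows (hypothetical names; the lines elided from the hunk above are assumed to do the standard running-max/denominator bookkeeping):

    // Minimal scalar sketch of one online-softmax step (illustrative;
    // OnlineSoftmaxStep is a hypothetical name, not the repo's function).
    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    void OnlineSoftmaxStep(float x, float& m, float& d, const float* v,
                           size_t v_cols, float* out) {
      const float new_m = std::max(m, x);            // running max of logits
      const float correction = std::exp(m - new_m);  // rescales the old state
      const float p = std::exp(x - new_m);           // weight of the new v row
      const float new_d = d * correction + p;        // softmax denominator
      // Keep `out` normalized by the current denominator after every step:
      //   out = (out * d * correction + p * v) / new_d.
      const float scale = d * correction / new_d;
      for (size_t i = 0; i < v_cols; ++i) {
        out[i] = out[i] * scale + (p / new_d) * v[i];
      }
      m = new_m;
      d = new_d;
    }

Initializing m to -infinity and d to 0 makes the first step reduce to out = v, so no special-casing is needed.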
@@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
     float x = Dot(q, k.Row(pos_mod), k.Cols());
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
-                             v.Row(pos_mod), v.Cols(), att_out, p, worker);
+                             v.Row(pos_mod), v.Cols(), att_out);
   }
 }
 
@@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
                const size_t k_pos, const MatPtrT<KV_t>& q,
-               const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) {
+               const MatPtrT<KV_t>& k) {
   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
     results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
@@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
 // consecutive elements, and other columns by adding q_stride.
 template <class DF, class VF = hn::Vec<DF>>
 void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
-                    const MatPtrT<KV_t>& k, const size_t* k_pos,
-                    hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
-                    VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
+                    const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
+                    VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
                     VF& sum7) {
   constexpr size_t kHTileSize = kNFx8HTileSize;
   sum0 = hn::Zero(df);
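
QDotKTileFloat keeps eight running sums so that one pass over the shared column dimension can feed eight k positions at once, reusing each loaded q value instead of re-reading it per k row. A scalar sketch of that register-blocking idea (hypothetical helper, not the vectorized kernel itself):

    // Scalar sketch of 8-way register blocking (DotOneQAgainstEightK is a
    // hypothetical helper, not the vectorized QDotKTileFloat itself).
    #include <cstddef>

    void DotOneQAgainstEightK(const float* q, const float* const k_rows[8],
                              size_t cols, float sums[8]) {
      for (int j = 0; j < 8; ++j) sums[j] = 0.0f;
      for (size_t c = 0; c < cols; ++c) {
        const float qc = q[c];  // load each q element once...
        for (int j = 0; j < 8; ++j) {
          sums[j] += qc * k_rows[j][c];  // ...and reuse it for all eight k rows
        }
      }
    }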
@@ -303,8 +303,8 @@ void TileFlashAttention(
       k_pos[i] = activations.div_seq_len.Remainder(position + i);
     }
     VF x0, x1, x2, x3, x4, x5, x6, x7;
-    QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3,
-                   x4, x5, x6, x7);
+    QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
+                   x7);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
       VF cap = hn::Set(df, activations.config.att_cap);
@@ -343,12 +343,12 @@ void TileFlashAttention(
     x6 = hn::Mul(x6, one_over_d);
     x7 = hn::Mul(x7, one_over_d);
     MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
-                         att_out.Row(0), out_offsets, v.Cols(), p, worker);
+                         att_out.Row(0), out_offsets, v.Cols());
     position += kHTileSize;
   }
   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
-    VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker);
+    VF x0 = QDotKVector(df, q_offsets, k_pos, q, k);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
       VF cap = hn::Set(df, activations.config.att_cap);
@@ -369,7 +369,7 @@ void TileFlashAttention(
     x0 = hn::Mul(x0, one_over_d);
     scale = hn::Mul(scale, one_over_d);
     MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
-                           v.Cols(), p, worker);
+                           v.Cols());
     ++position;
   }
 }
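
TileFlashAttention therefore follows the common main-loop/epilogue split: the tiled kernel consumes kHTileSize positions per iteration while a full tile fits, and the remaining positions fall back to the one-at-a-time vector path above. A generic sketch of that structure (the exact loop bound is an assumption, since it lies outside this diff):

    // Generic sketch of the tile-then-remainder loop structure (assumed;
    // TiledLoop is a hypothetical helper, not code from this file).
    #include <cstddef>

    template <class TileKernel, class ScalarKernel>
    void TiledLoop(size_t position, size_t max_last_pos, size_t tile,
                   TileKernel tile_kernel, ScalarKernel scalar_kernel) {
      // Main loop: one full tile of k positions per iteration.
      while (position + tile <= max_last_pos + 1) {
        tile_kernel(position);
        position += tile;
      }
      // Epilogue: remaining positions, one at a time.
      while (position <= max_last_pos) {
        scalar_kernel(position);
        ++position;
      }
    }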
@@ -380,8 +380,8 @@ void TileFlashAttention(
 template <class DF, class VF = hn::Vec<DF>>
 void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
                  const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
-                 const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
-                 const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
+                 const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
+                 VF& sum2, VF& sum3) {
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
   sum2 = hn::Zero(df);
@@ -462,8 +462,7 @@ void TileFlashAttention4(
       k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
     }
     VF x0, x1, x2, x3;
-    QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
-                x3);
+    QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
       VF cap = hn::Set(df, activations.config.att_cap);
@@ -478,7 +477,7 @@ void TileFlashAttention4(
     scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
     scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
     MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
-                          out_offsets, v.Cols(), p, worker);
+                          out_offsets, v.Cols());
     position += kHTileSize;
   }
   while (position <= max_last_pos) {
@@ -488,28 +487,28 @@ void TileFlashAttention4(
       float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
       SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
                                v.Row(k_pos), v.Cols(),
-                               att_out.Row(0) + out_offsets[0], p, worker);
+                               att_out.Row(0) + out_offsets[0]);
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
       float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
       SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
                                v.Row(k_pos), v.Cols(),
-                               att_out.Row(0) + out_offsets[1], p, worker);
+                               att_out.Row(0) + out_offsets[1]);
     }
     if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
       float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
       SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
                                v.Row(k_pos), v.Cols(),
-                               att_out.Row(0) + out_offsets[2], p, worker);
+                               att_out.Row(0) + out_offsets[2]);
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
       float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
       SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
                                v.Row(k_pos), v.Cols(),
-                               att_out.Row(0) + out_offsets[3], p, worker);
+                               att_out.Row(0) + out_offsets[3]);
     }
     ++position;
   }