Skip to content

Commit eb4412f

Browse files
theraysmithcopybara-github
authored andcommitted
Added access to flash attention internals to TileFlashAttention4
PiperOrigin-RevId: 825973042
1 parent ee7d79c commit eb4412f

File tree

5 files changed

+87
-49
lines changed

5 files changed

+87
-49
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ cc_library(
543543
"gemma/activations.h",
544544
"gemma/attention.h",
545545
"gemma/flash_attention.h",
546+
"gemma/flash_structs.h",
546547
"gemma/gemma.h",
547548
"gemma/vit.h",
548549
],

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ set(SOURCES
8282
gemma/configs.h
8383
gemma/flash_attention.cc
8484
gemma/flash_attention.h
85+
gemma/flash_structs.h
8586
gemma/gemma_args.h
8687
gemma/gemma-inl.h
8788
gemma/gemma.cc

gemma/flash_attention.cc

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <limits>
2222

2323
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
24+
#include "gemma/flash_structs.h"
2425
#include "util/threading_context.h"
2526
#include "util/zones.h"
2627
#ifndef HWY_DISABLED_TARGETS
@@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
444445
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
445446
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
446447
// max_last_pos].
447-
void TileFlashAttention4(const MatPtrT<float>& q,
448-
const uint32_t* HWY_RESTRICT q_offsets,
449-
const MatPtrT<KV_t>& k, const size_t start_pos,
450-
const uint32_t* HWY_RESTRICT last_pos,
451-
const size_t min_last_pos, const size_t max_last_pos,
452-
const MatPtrT<KV_t>& v, const size_t layer_idx,
453-
const AttentionActivationsPtrs& activations,
454-
MatPtrT<float>& att_out,
455-
const uint32_t* HWY_RESTRICT out_offsets,
456-
ThreadingContext& ctx, const size_t worker) {
448+
Tile4FlashState TileFlashAttention4(
449+
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
450+
const MatPtrT<KV_t>& k, const size_t start_pos,
451+
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
452+
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
453+
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
454+
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
455+
const size_t worker) {
457456
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
458457
using DF = hn::ScalableTag<float>;
459458
const DF df;
@@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT<float>& q,
467466
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
468467
v.Cols() * sizeof(att_out.Row(0)[0]));
469468
}
470-
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
471-
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
472-
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
473-
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
474-
float old_d0 = 0.0f;
475-
float old_d1 = 0.0f;
476-
float old_d2 = 0.0f;
477-
float old_d3 = 0.0f;
469+
Tile4FlashState state;
478470
size_t position = start_pos;
479471
while (position + kHTileSize - 1 <= min_last_pos) {
480472
int32_t k_offsets[kMaxNF];
@@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
494486
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
495487
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
496488
}
497-
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
498-
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
499-
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
500-
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
489+
scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max,
490+
state.row_states[0].d);
491+
scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max,
492+
state.row_states[1].d);
493+
scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max,
494+
state.row_states[2].d);
495+
scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max,
496+
state.row_states[3].d);
501497
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
502498
out_offsets, v.Cols());
503499
position += kHTileSize;
@@ -516,7 +512,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
516512
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
517513
float x0 =
518514
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
519-
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
515+
SingleFlashAttentionStep(x0, activations.config.att_cap,
516+
state.row_states[0].max, state.row_states[0].d,
520517
v.Row(k_pos), v.Cols(),
521518
att_out.Row(0) + out_offsets[0]);
522519
}
@@ -526,7 +523,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
526523
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
527524
float x1 =
528525
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
529-
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
526+
SingleFlashAttentionStep(x1, activations.config.att_cap,
527+
state.row_states[1].max, state.row_states[1].d,
530528
v.Row(k_pos), v.Cols(),
531529
att_out.Row(0) + out_offsets[1]);
532530
}
@@ -536,7 +534,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
536534
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
537535
float x2 =
538536
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
539-
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
537+
SingleFlashAttentionStep(x2, activations.config.att_cap,
538+
state.row_states[2].max, state.row_states[2].d,
540539
v.Row(k_pos), v.Cols(),
541540
att_out.Row(0) + out_offsets[2]);
542541
}
@@ -546,12 +545,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
546545
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
547546
float x3 =
548547
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
549-
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
548+
SingleFlashAttentionStep(x3, activations.config.att_cap,
549+
state.row_states[3].max, state.row_states[3].d,
550550
v.Row(k_pos), v.Cols(),
551551
att_out.Row(0) + out_offsets[3]);
552552
}
553553
++position;
554554
}
555+
return state;
555556
}
556557

557558
// Rounds n to a number that can be used as the number of Q rows in a tile

gemma/flash_attention.h

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,47 @@
2020

2121
#include <stddef.h>
2222

23+
#include <cstdint>
24+
25+
#include "gemma/flash_structs.h"
2326
#include "gemma/gemma.h"
2427
#include "hwy/highway.h"
2528

2629
namespace gcpp {
2730

2831
// Passed to HWY_VISIT_TARGETS; declares for one target.
29-
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
30-
namespace NAMESPACE { \
31-
void RMSNormAndPositionalEncoding( \
32-
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
33-
const MatPtr& query_norm_scale, size_t layer_idx, \
34-
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
35-
\
36-
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
37-
const float* HWY_RESTRICT q, \
38-
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
39-
size_t layer_idx, \
40-
const AttentionActivationsPtrs& activations, \
41-
float* HWY_RESTRICT att_out, \
42-
ThreadingContext& ctx, size_t worker); \
43-
\
44-
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
45-
size_t total_tasks, size_t target_parallelism); \
46-
\
47-
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
48-
size_t layer_idx, const MatPtr& query_norm_scale, \
49-
AttentionActivationsPtrs& activations, QBatch& qbatch, \
50-
ThreadingContext& ctx); \
51-
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
32+
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
33+
namespace NAMESPACE { \
34+
void RMSNormAndPositionalEncoding( \
35+
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
36+
const MatPtr& query_norm_scale, size_t layer_idx, \
37+
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
38+
\
39+
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
40+
const float* HWY_RESTRICT q, \
41+
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
42+
size_t layer_idx, \
43+
const AttentionActivationsPtrs& activations, \
44+
float* HWY_RESTRICT att_out, \
45+
ThreadingContext& ctx, size_t worker); \
46+
\
47+
Tile4FlashState TileFlashAttention4( \
48+
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \
49+
const MatPtrT<KV_t>& k, size_t start_pos, \
50+
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
51+
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
52+
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
53+
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
54+
ThreadingContext& ctx, const size_t worker); \
55+
\
56+
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
57+
size_t total_tasks, size_t target_parallelism); \
58+
\
59+
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
60+
size_t layer_idx, const MatPtr& query_norm_scale, \
61+
AttentionActivationsPtrs& activations, QBatch& qbatch, \
62+
ThreadingContext& ctx); \
63+
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
5264
} // namespace NAMESPACE
5365

5466
// Function declarations for each SIMD target. Allows direct call from the

gemma/flash_structs.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
2+
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
3+
4+
#include <stddef.h>
5+
6+
#include <limits>
7+
8+
namespace gcpp {
9+
10+
struct OnlineSoftmaxState {
11+
float max = -std::numeric_limits<float>::max() / 2.0f;
12+
float d = 0.0f;
13+
};
14+
15+
static constexpr size_t kVTileSize4 = 4;
16+
17+
struct Tile4FlashState {
18+
OnlineSoftmaxState row_states[kVTileSize4];
19+
};
20+
21+
} // namespace gcpp
22+
23+
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_

0 commit comments

Comments
 (0)