Skip to content

Commit ee18916

Browse files
theraysmithcopybara-github
authored andcommitted
Removed the PROFILER_ZONE from the most highly called functions to reduce the overhead.
PiperOrigin-RevId: 819739402
1 parent e3e8511 commit ee18916

File tree

6 files changed

+43
-59
lines changed

6 files changed

+43
-59
lines changed

gemma/attention.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
8989
// PostQKType::Rope
9090
if (post_qk == PostQKType::HalfRope) {
9191
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
92-
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
92+
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
9393
} else {
9494
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
9595
}
@@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
113113
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
114114
worker);
115115
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
116-
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
116+
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
117117
}
118118
} else {
119119
{
@@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
122122
}
123123
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
124124
const size_t pos_mod = div_seq_len.Remainder(pos);
125-
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
126-
worker);
125+
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
127126
}
128127
}
129128
}

gemma/flash_attention.cc

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
131131
}
132132

133133
// Handles a single v row of flash attention for a single q.k dot product.
134-
void HWY_INLINE SingleFlashAttentionStep(
135-
float x, float cap, float& old_max, float& old_d,
136-
const float* HWY_RESTRICT v, const size_t v_cols,
137-
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
134+
void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
135+
float& old_d,
136+
const float* HWY_RESTRICT v,
137+
const size_t v_cols,
138+
float* HWY_RESTRICT att_out) {
138139
if (cap > 0.0f) {
139140
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
140141
x = cap * std::tanh(x / cap);
@@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep(
147148
float one_over_d = 1.0f / old_d;
148149
scale *= one_over_d;
149150
x *= one_over_d;
150-
MulByConst(scale, att_out, v_cols, p, worker);
151-
MulByConstAndAdd(x, v, att_out, v_cols, p, worker);
151+
MulByConst(scale, att_out, v_cols);
152+
MulByConstAndAdd(x, v, att_out, v_cols);
152153
}
153154

154155
// Calculates the complete attention outputs for a single row of q.
@@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
174175
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
175176
float x = Dot(q, k.Row(pos_mod), k.Cols());
176177
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
177-
v.Row(pos_mod), v.Cols(), att_out, p, worker);
178+
v.Row(pos_mod), v.Cols(), att_out);
178179
}
179180
}
180181

@@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
183184
template <class DF, class VF = hn::Vec<DF>>
184185
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
185186
const size_t k_pos, const MatPtrT<KV_t>& q,
186-
const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) {
187+
const MatPtrT<KV_t>& k) {
187188
hn::TFromD<DF> results[hn::MaxLanes(df)];
188189
for (size_t i = 0; i < hn::Lanes(df); ++i) {
189190
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
@@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
198199
// consecutive elements, and other columns by adding q_stride.
199200
template <class DF, class VF = hn::Vec<DF>>
200201
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
201-
const MatPtrT<KV_t>& k, const size_t* k_pos,
202-
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
203-
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
202+
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
203+
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
204204
VF& sum7) {
205205
constexpr size_t kHTileSize = kNFx8HTileSize;
206206
sum0 = hn::Zero(df);
@@ -303,8 +303,8 @@ void TileFlashAttention(
303303
k_pos[i] = activations.div_seq_len.Remainder(position + i);
304304
}
305305
VF x0, x1, x2, x3, x4, x5, x6, x7;
306-
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3,
307-
x4, x5, x6, x7);
306+
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
307+
x7);
308308
if (activations.config.att_cap > 0.0f) {
309309
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
310310
VF cap = hn::Set(df, activations.config.att_cap);
@@ -343,12 +343,12 @@ void TileFlashAttention(
343343
x6 = hn::Mul(x6, one_over_d);
344344
x7 = hn::Mul(x7, one_over_d);
345345
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
346-
att_out.Row(0), out_offsets, v.Cols(), p, worker);
346+
att_out.Row(0), out_offsets, v.Cols());
347347
position += kHTileSize;
348348
}
349349
while (position <= max_last_pos) {
350350
size_t k_pos = activations.div_seq_len.Remainder(position);
351-
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker);
351+
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k);
352352
if (activations.config.att_cap > 0.0f) {
353353
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
354354
VF cap = hn::Set(df, activations.config.att_cap);
@@ -369,7 +369,7 @@ void TileFlashAttention(
369369
x0 = hn::Mul(x0, one_over_d);
370370
scale = hn::Mul(scale, one_over_d);
371371
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
372-
v.Cols(), p, worker);
372+
v.Cols());
373373
++position;
374374
}
375375
}
@@ -380,8 +380,8 @@ void TileFlashAttention(
380380
template <class DF, class VF = hn::Vec<DF>>
381381
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
382382
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
383-
const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
384-
const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
383+
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
384+
VF& sum2, VF& sum3) {
385385
sum0 = hn::Zero(df);
386386
sum1 = hn::Zero(df);
387387
sum2 = hn::Zero(df);
@@ -462,8 +462,7 @@ void TileFlashAttention4(
462462
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
463463
}
464464
VF x0, x1, x2, x3;
465-
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
466-
x3);
465+
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3);
467466
if (activations.config.att_cap > 0.0f) {
468467
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
469468
VF cap = hn::Set(df, activations.config.att_cap);
@@ -478,7 +477,7 @@ void TileFlashAttention4(
478477
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
479478
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
480479
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
481-
out_offsets, v.Cols(), p, worker);
480+
out_offsets, v.Cols());
482481
position += kHTileSize;
483482
}
484483
while (position <= max_last_pos) {
@@ -488,28 +487,28 @@ void TileFlashAttention4(
488487
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
489488
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
490489
v.Row(k_pos), v.Cols(),
491-
att_out.Row(0) + out_offsets[0], p, worker);
490+
att_out.Row(0) + out_offsets[0]);
492491
}
493492
if (position <= last_pos[1]) {
494493
// Past the last position, x1 doesn't count.
495494
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
496495
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
497496
v.Row(k_pos), v.Cols(),
498-
att_out.Row(0) + out_offsets[1], p, worker);
497+
att_out.Row(0) + out_offsets[1]);
499498
}
500499
if (position <= last_pos[2]) {
501500
// Past the last position, x2 doesn't count.
502501
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
503502
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
504503
v.Row(k_pos), v.Cols(),
505-
att_out.Row(0) + out_offsets[2], p, worker);
504+
att_out.Row(0) + out_offsets[2]);
506505
}
507506
if (position <= last_pos[3]) {
508507
// Past the last position, x3 doesn't count.
509508
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
510509
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
511510
v.Row(k_pos), v.Cols(),
512-
att_out.Row(0) + out_offsets[3], p, worker);
511+
att_out.Row(0) + out_offsets[3]);
513512
}
514513
++position;
515514
}

gemma/gemma.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
160160

161161
const size_t model_dim = model_config.model_dim;
162162
const float emb_scaling = EmbeddingScaling(model_dim);
163-
const size_t worker = 0; // Not yet parallelized.
164163

165164
HWY_DASSERT(token >= 0);
166165
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
@@ -176,8 +175,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
176175
const hn::ScalableTag<float> df;
177176
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
178177
model_dim);
179-
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
180-
ctx.profiler, worker);
178+
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim);
181179
});
182180

183181
if (model_config.absolute_pe) {

gemma/vit.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class VitAttention {
9595
float* HWY_RESTRICT q =
9696
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
9797
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
98-
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
98+
MulByConst(query_scale, q, qkv_dim);
9999
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
100100
});
101101

@@ -120,8 +120,7 @@ class VitAttention {
120120
for (size_t i = 0; i < seq_len; ++i) {
121121
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
122122
head * 3 * qkv_dim + 2 * qkv_dim;
123-
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
124-
env_.ctx.profiler, worker);
123+
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
125124
}
126125
});
127126
}
@@ -144,7 +143,7 @@ class VitAttention {
144143
// Compute Q.K scores, which are "logits" stored in head_att.
145144
float* HWY_RESTRICT q =
146145
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
147-
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
146+
MulByConst(query_scale, q, qkv_dim);
148147
float* HWY_RESTRICT head_att =
149148
activations_.attention.att.Row(token) + head * seq_len;
150149
for (size_t i = 0; i < seq_len; ++i) {
@@ -161,8 +160,7 @@ class VitAttention {
161160
for (size_t i = 0; i < seq_len; ++i) {
162161
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
163162
head * 3 * qkv_dim + 2 * qkv_dim;
164-
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
165-
env_.ctx.profiler, worker);
163+
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
166164
}
167165
});
168166
}

ops/ops-inl.h

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
560560

561561
template <typename XT>
562562
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
563-
const size_t size,
564-
hwy::Profiler& p,
565-
const size_t worker) {
566-
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst));
563+
const size_t size) {
567564
namespace hn = hwy::HWY_NAMESPACE;
568565
using DF = hn::ScalableTag<float>;
569566
using VF = hn::Vec<DF>;
@@ -596,10 +593,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
596593

597594
// out[i] += x[i] * c.
598595
template <typename XT, typename OT>
599-
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
600-
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
601-
const size_t size, hwy::Profiler& p, const size_t worker) {
602-
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd));
596+
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
597+
const XT* HWY_RESTRICT x,
598+
OT* HWY_RESTRICT out,
599+
const size_t size) {
603600
namespace hn = hwy::HWY_NAMESPACE;
604601
using DF = hn::ScalableTag<float>;
605602
using VF = hn::Vec<DF>;
@@ -734,9 +731,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
734731
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
735732
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
736733
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
737-
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
738-
hwy::Profiler& p, const size_t worker) {
739-
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile));
734+
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
740735
namespace hn = hwy::HWY_NAMESPACE;
741736
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
742737

@@ -996,9 +991,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
996991
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
997992
const VF c2, const VF c3, const MatPtrT<float>& v,
998993
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
999-
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
1000-
hwy::Profiler& p, const size_t worker) {
1001-
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4));
994+
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
1002995
namespace hn = hwy::HWY_NAMESPACE;
1003996
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
1004997

@@ -1037,9 +1030,7 @@ template <class DF, class VF = hn::Vec<DF>>
10371030
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
10381031
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
10391032
const size_t pos, float* HWY_RESTRICT out,
1040-
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
1041-
hwy::Profiler& p, const size_t worker) {
1042-
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector));
1033+
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
10431034
namespace hn = hwy::HWY_NAMESPACE;
10441035
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
10451036

@@ -1177,7 +1168,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
11771168
const float sum_exp = Sum(d, logits.data(), logits.size());
11781169
// Double-precision reciprocal does not appear to affect the results.
11791170
const float mul = 1.0f / sum_exp;
1180-
MulByConst(mul, logits.data(), logits.size(), p, worker);
1171+
MulByConst(mul, logits.data(), logits.size());
11811172
}
11821173

11831174
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /

ops/ops_test.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class TestMulByConstAndAdd {
183183

184184
SimpleMulByConstAndAdd(constant, o, e, count);
185185
InitProfilerZones(hwy::Profiler::Get());
186-
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
186+
MulByConstAndAdd(constant, o, x, count);
187187

188188
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
189189
__LINE__);
@@ -232,7 +232,7 @@ class TestMulByConst {
232232

233233
SimpleMulByConst(constant, e, count);
234234
InitProfilerZones(hwy::Profiler::Get());
235-
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
235+
MulByConst(constant, x, count);
236236

237237
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
238238
__LINE__);
@@ -443,7 +443,6 @@ void TestRopeAndMulBy() {
443443
ThreadingArgs threading_args;
444444
ThreadingContext ctx(threading_args);
445445
hwy::Profiler& p = ctx.profiler;
446-
InitProfilerZones(p);
447446
const size_t worker = 0;
448447

449448
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,

0 commit comments

Comments
 (0)