@@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
 }
 
 void PositionalEncodingQK(float* qk, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           ThreadingContext& ctx, const size_t worker,
                           const size_t pos, const float mul) {
-  const size_t qkv_dim = layer.layer_config.qkv_dim;
-  const PostQKType& post_qk = layer.layer_config.post_qk;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const PostQKType& post_qk = layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
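
Not part of the patch, but as context for this hunk: PositionalEncodingQK applies a rotary positional encoding to a q or k vector and folds in a scale (`mul` is the query scale for q and 1.0 for k). Below is a minimal sketch of that rotation, assuming the half-split pairing `(i, i + qkv_dim / 2)` and a precomputed `inv_timescale` table; the function and parameter names are illustrative, not the gemma.cpp API.

```cpp
// Illustrative sketch only: rotary positional encoding of one q or k vector
// using the half-split pairing (i, i + qkv_dim / 2), with `mul` standing in
// for the query scale (1.0 when encoding k).
#include <cmath>
#include <cstddef>

void RopeAndMulSketch(float* qk, size_t qkv_dim, const float* inv_timescale,
                      size_t pos, float mul) {
  const size_t half = qkv_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = static_cast<float>(pos) * inv_timescale[i];
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = qk[i];
    const float x1 = qk[i + half];
    qk[i] = mul * (x0 * c - x1 * s);         // rotate the (x0, x1) pair
    qk[i + half] = mul * (x0 * s + x1 * c);  // by theta, then scale
  }
}
```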
@@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
+    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
-
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   // Apply rope and scaling to Q.
-  if (layer.query_norm_scale.HasPtr()) {
-    CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+  if (query_norm_scale.HasPtr()) {
+    CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
-                     layer.layer_config.qkv_dim, ctx, worker);
+                     layer_config.qkv_dim, ctx, worker);
     });
   }
 
-  PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
                        query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
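
Again as background rather than part of the change: SingleDotSoftmaxWeightedSum scores q against the cached keys, optionally soft-caps the logits with `att_cap`, applies softmax, and accumulates the weighted sum over the cached values into `att_out`. The sketch below shows that core with illustrative container types standing in for the MatPtrT views; names are hypothetical.

```cpp
// Illustrative sketch only: dot products against cached keys, optional logit
// soft cap, numerically stable softmax, then the weighted sum of value rows.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

void AttendOneQuerySketch(const float* q,
                          const std::vector<std::vector<float>>& k,
                          const std::vector<std::vector<float>>& v,
                          size_t qkv_dim, float att_cap,
                          float* att_out) {
  const size_t n = k.size();
  std::vector<float> att(n);
  float max_logit = -std::numeric_limits<float>::infinity();
  for (size_t t = 0; t < n; ++t) {
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[t][d];
    if (att_cap > 0.0f) dot = att_cap * std::tanh(dot / att_cap);  // soft cap
    att[t] = dot;
    max_logit = std::max(max_logit, dot);
  }
  float sum = 0.0f;
  for (size_t t = 0; t < n; ++t) {
    att[t] = std::exp(att[t] - max_logit);  // stable softmax numerator
    sum += att[t];
  }
  std::fill(att_out, att_out + qkv_dim, 0.0f);
  for (size_t t = 0; t < n; ++t) {
    const float w = att[t] / sum;
    for (size_t d = 0; d < qkv_dim; ++d) att_out[d] += w * v[t][d];
  }
}
```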
@@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const LayerWeightsPtrs& layer,
-                           AttentionActivations& activations, QBatch& qbatch,
-                           ThreadingContext& ctx) {
+                           const MatPtrT<float>& query_norm_scale,
+                           AttentionActivationsPtrs& activations,
+                           QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
 
   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query
@@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
     v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
-    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
-                                layer, activations, att, att_out, ctx, worker);
+    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+                                query_norm_scale, layer_idx, activations, att,
+                                att_out, ctx, worker);
   };
 
   {
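
The truncated comment in the previous hunk refers to GQA head groups: several query heads share one K/V head, so the `kv_head_offset` used for the k/v views above is derived from the query head index. A hedged sketch of that mapping, with hypothetical names:

```cpp
// Illustrative sketch of the GQA head-group mapping: with q_heads query
// heads sharing kv_heads K/V heads, each group of q_heads / kv_heads query
// heads reads the same KV-cache head slot.
#include <cassert>
#include <cstddef>

size_t KVHeadForQueryHead(size_t q_head, size_t q_heads, size_t kv_heads) {
  assert(kv_heads != 0 && q_heads % kv_heads == 0);
  const size_t group_size = q_heads / kv_heads;  // query heads per group
  return q_head / group_size;
}
// Example: with 8 query heads and 2 KV heads, query heads 0..3 map to
// KV head 0 and query heads 4..7 map to KV head 1.
```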
@@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   const LayerWeightsPtrs& layer,
-                                  AttentionActivations& activations,
+                                  AttentionActivationsPtrs& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
@@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
       });
     }
 
-    PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
-                         worker, pos, /*mul=*/1.0f);
+    PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
+                         pos, /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });
@@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivations& activations,
+                                AttentionActivationsPtrs& activations,
                                 MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;
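
For context on the surrounding function: SumHeads projects the concatenated per-head attention outputs (`heads * qkv_dim` values per token) into `model_dim` through the layer's attention output weights, which the real code performs as a MatMul via MatMulEnv. A plain-loop sketch of the same reduction, assuming a row-major weight layout and illustrative names:

```cpp
// Illustrative sketch only: fold heads * qkv_dim per-token attention outputs
// into model_dim using a row-major output projection matrix.
#include <cstddef>
#include <vector>

std::vector<float> SumHeadsSketch(const std::vector<float>& att_out,   // heads * qkv_dim
                                  const std::vector<float>& w_output,  // (heads * qkv_dim) x model_dim
                                  size_t heads, size_t qkv_dim,
                                  size_t model_dim) {
  std::vector<float> layer_out(model_dim, 0.0f);
  for (size_t hd = 0; hd < heads * qkv_dim; ++hd) {
    const float x = att_out[hd];
    for (size_t m = 0; m < model_dim; ++m) {
      layer_out[m] += x * w_output[hd * model_dim + m];
    }
  }
  return layer_out;
}
```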
@@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
 
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     MatMulEnv& env, int flags) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
 
@@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
 
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
   if (flags & kAttentionUseOld) {
-    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
-                          env.ctx);
+    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
+                          activations, qbatch, env.ctx);
   } else {
     // * 2 does not help on Turin.
     FlashAttention(num_tokens,
                    /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
-                   layer_idx, layer, activations, qbatch, env.ctx);
+                   layer_idx, layer.query_norm_scale, activations, qbatch,
+                   env.ctx);
   }
   SumHeads(layer, activations, env);
 }