@@ -232,8 +232,8 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   const size_t num_interleaved = num_tokens * num_queries;
 
   // Self extend
-  constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
-  constexpr size_t grp_size = TConfig::self_extend_grp_size;
+  constexpr size_t ngb_size = TConfig::kSelfExtendNgbSize;
+  constexpr size_t grp_size = TConfig::kSelfExtendGrpSize;
 
   // For the computation of Q, K, and V, it is useful to remember that
   // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
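The renaming in this hunk brings the two tuning knobs in line with the repo's `kConstantCase` convention for compile-time config members. As a rough sketch (the values are hypothetical, not taken from any real Gemma config), the constants would live on the config type like this:

```cpp
#include <cstddef>

// Hypothetical config fragment; only the self-extend members are shown.
struct ConfigWithSelfExtend {
  static constexpr bool kSelfExtend = true;           // enable self-extend
  static constexpr size_t kSelfExtendNgbSize = 1024;  // neighbor window size
  static constexpr size_t kSelfExtendGrpSize = 4;     // group size for distant tokens
};
```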
@@ -298,8 +298,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
     float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
     // When embedding position, we will use grouped key position
-    if (pos > ngb_size && TConfig::kSelfExtend) {
-      pos /= grp_size;
+    if constexpr (TConfig::kSelfExtend) {
+      if (pos > ngb_size) {
+        pos /= grp_size;
+      }
     }
     if constexpr (kIsMHA) {
       // For MHA, copy KV into the KV cache from scratch space (see above).
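Note the control-flow change here: the runtime `pos > ngb_size && TConfig::kSelfExtend` test becomes an `if constexpr` on the config flag with the position check nested inside, so configs compiled without self-extend drop the grouping code entirely instead of branching over it at run time. A minimal sketch of the same pattern, with the hunk's logic factored into a hypothetical helper:

```cpp
#include <cstddef>

template <class TConfig>
size_t MaybeGroupKeyPos(size_t pos) {
  // The if constexpr body is discarded at compile time when
  // TConfig::kSelfExtend is false, so non-self-extend configs pay nothing.
  if constexpr (TConfig::kSelfExtend) {
    if (pos > TConfig::kSelfExtendNgbSize) {
      pos /= TConfig::kSelfExtendGrpSize;  // grouped key position
    }
  }
  return pos;
}
```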
@@ -331,11 +333,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
 
     // Apply rope and scaling to Q.
     size_t pos = batch_start + batch_idx;
-    if (pos > ngb_size && TConfig::kSelfExtend) {
-      const grp_pos = pos / grp_size;
-      const shift = ngb_size - ngb_size / grp_size
-      const shifted_grouped_pos = grp_pos + shift
-      pos = shifted_grouped_pos;
+    if constexpr (TConfig::kSelfExtend) {
+      if (pos > ngb_size) {
+        const size_t grp_pos = pos / grp_size;
+        const size_t shift = ngb_size - ngb_size / grp_size;
+        const size_t shifted_grouped_pos = grp_pos + shift;
+        pos = shifted_grouped_pos;
+      }
    }
     PostQK<TConfig>(q, pos, layer);
     MulByConst(kQueryScale, q, kQKVDim);
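Taken together, the two hunks implement the Self-Extend position remapping: keys beyond the neighbor window are bucketed by integer division, and queries get the same bucketing plus a constant shift so query-key distances stay continuous across the window boundary. A standalone sketch of the mapping, using hypothetical stand-ins for `kSelfExtendNgbSize` and `kSelfExtendGrpSize`:

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical values; the real ones come from TConfig::kSelfExtendNgbSize
// and TConfig::kSelfExtendGrpSize.
constexpr size_t kNgbSize = 1024;
constexpr size_t kGrpSize = 4;

// Key position: positions past the neighbor window collapse into groups.
size_t KeyPos(size_t pos) {
  if (pos > kNgbSize) pos /= kGrpSize;
  return pos;
}

// Query position: the same grouping, plus a shift so a query just past the
// window still lands near kNgbSize, keeping distances to recent keys intact.
size_t QueryPos(size_t pos) {
  if (pos > kNgbSize) {
    const size_t grp_pos = pos / kGrpSize;
    const size_t shift = kNgbSize - kNgbSize / kGrpSize;
    pos = grp_pos + shift;
  }
  return pos;
}

int main() {
  const size_t positions[] = {1000, 1025, 4096, 16384};
  for (size_t pos : positions) {
    std::printf("pos=%5zu  key=%5zu  query=%5zu\n", pos, KeyPos(pos), QueryPos(pos));
  }
  return 0;
}
```

With these numbers, a query at position 1025 maps to 1024, while ungrouped keys at positions up to 1024 keep their original positions, so distances to recent tokens are preserved and only far-away keys are compressed into shared group positions.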