@@ -231,6 +231,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
231231 const size_t batch_start = interleaved_start / num_queries;
232232 const size_t num_interleaved = num_tokens * num_queries;
233233
234+ // Self extend
235+ constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
236+ constexpr size_t grp_size = TConfig::self_extend_grp_size;
237+
234238 // For the computation of Q, K, and V, it is useful to remember that
235239 // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
236240 // and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
@@ -286,12 +290,17 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
286290 const size_t interleaved_idx = task / kKVHeads ;
287291 const size_t query_idx = interleaved_idx % num_queries;
288292 const size_t batch_idx = interleaved_idx / num_queries;
289- const size_t pos = batch_start + batch_idx;
293+ size_t pos = batch_start + batch_idx;
290294 const size_t cache_pos = div_seq_len.Remainder (pos);
291295 const size_t kv_offset = cache_pos * kCachePosSize +
292296 layer * kCacheLayerSize + head * kQKVDim * 2 ;
293297 KVCache& kv_cache = kv_caches[query_idx];
294298 float * HWY_RESTRICT kv = kv_cache.kv_cache .get () + kv_offset;
299+
300+ // When embedding position, we will use grouped key position
301+ if (pos > ngb_size && TConfig::kSelfExtend ) {
302+ pos /= grp_size;
303+ }
295304 if constexpr (kIsMHA ) {
296305 // For MHA, copy KV into the KV cache from scratch space (see above).
297306 const float * HWY_RESTRICT q =
@@ -321,7 +330,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
321330 activations.q .Batch (interleaved_idx) + head * kQStride ;
322331
323332 // Apply rope and scaling to Q.
324- const size_t pos = batch_start + batch_idx;
333+ size_t pos = batch_start + batch_idx;
334+ if (pos > ngb_size && TConfig::kSelfExtend ) {
335+ const grp_pos = pos / grp_size;
336+ const shift = ngb_size - ngb_size / grp_size
337+ const shifted_grouped_pos = grp_pos + shift
338+ pos = shifted_grouped_pos;
339+ }
325340 PostQK<TConfig>(q, pos, layer);
326341 MulByConst (kQueryScale , q, kQKVDim );
327342
0 commit comments