4 changes: 4 additions & 0 deletions gemma/configs.h
@@ -127,6 +127,10 @@ struct LayerConfig {
size_t conv1d_width = 0;
bool ff_biases = false;
bool softmax_attn_output_biases = false;
bool self_extend = false;
size_t ngb_size = 0;
Member:
Is this n-gram block? Maybe expand to block_size for more clarity? We can also move these three new fields into a section (just newline before them) with a // Self-extension comment.

Author:
@jan-wassenberg Sorry, I didn't understand. I did it here because LayerConfig gets accessed during the attention mechanism.

Member:
Sorry to be unclear; I was suggesting renaming this to ngram_block_size.
It would also be good to add a newline plus a "// Self-extension" comment for visual separation from the other fields in this struct.

Author:
Oh, ngb is short for neighbour; I see the point of confusion now.

Member:
Ah :) Generally it's good to write out words.

Author:
Understood, I'll make the change.

size_t grp_size = 1;

PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
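For reference, a minimal sketch of the layout the reviewer is suggesting above, assuming the abbreviations are written out (the names neighbor_size and group_size are illustrative guesses here, not the names that were actually merged):

  bool ff_biases = false;
  bool softmax_attn_output_biases = false;

  // Self-extension
  bool self_extend = false;
  size_t neighbor_size = 0;  // formerly ngb_size; "ngb" is short for neighbour
  size_t group_size = 1;     // formerly grp_size

  PostNormType post_norm = PostNormType::None;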
24 changes: 21 additions & 3 deletions gemma/gemma-inl.h
@@ -312,21 +312,29 @@ class GemmaAttention {
const size_t interleaved_idx = task / layer_config_.kv_heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t pos = queries_pos_[query_idx] + batch_idx;
size_t pos = queries_pos_[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size_ +
layer_ * cache_layer_size_ +
head * layer_config_.qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];

const size_t grp_size = layer_config_.grp_size;
const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;

float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
layer_config_.qkv_dim;

// When embedding the position, self-extend uses the grouped key position.
if (self_extend && pos > ngb_size) {
pos /= grp_size;
}
// Copy from `q` if MHA, or apply in-place.
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
kv);

// If MHA, also copy V into KVCache.
if (is_mha_) {
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@@ -411,12 +419,22 @@ class GemmaAttention {
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t head_offset =
(head / kHeadGroups) * layer_config_.qkv_dim * 2;

const size_t grp_size = layer_config_.grp_size;
const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_;

// Apply rope and scaling to Q.
const size_t pos = queries_pos_[query_idx] + batch_idx;
size_t pos = queries_pos_[query_idx] + batch_idx;
if (self_extend && pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
const size_t shifted_grouped_pos = grp_pos + shift;
pos = shifted_grouped_pos;
}
PositionalEncodingQK(q, pos, layer_, query_scale, q);

const size_t start_pos = StartPos(pos, layer_);
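To make the two position tweaks above easier to follow, here is a small standalone sketch of the same arithmetic (the helper names and the numbers in the example are illustrative only, not part of the PR). Keys beyond the neighbour window simply use the grouped position; queries additionally add a shift chosen so the mapping stays continuous at the window boundary:

#include <cstddef>

// Grouped key position: beyond the neighbour window, keys collapse onto
// grouped positions (pos / group_size), matching the first hunk above.
size_t SelfExtendKeyPos(size_t pos, size_t neighbor_size, size_t group_size) {
  return (pos > neighbor_size) ? pos / group_size : pos;
}

// Shifted grouped query position, matching the second hunk above. The shift
// neighbor_size - neighbor_size / group_size makes the value at
// pos == neighbor_size equal to neighbor_size, so grouped positions continue
// seamlessly from the ungrouped neighbour window.
size_t SelfExtendQueryPos(size_t pos, size_t neighbor_size, size_t group_size) {
  if (pos <= neighbor_size) return pos;
  const size_t grouped = pos / group_size;
  const size_t shift = neighbor_size - neighbor_size / group_size;
  return grouped + shift;
}

// Example with neighbor_size = 512, group_size = 4:
//   SelfExtendKeyPos(600, 512, 4)   == 150  (600 / 4)
//   SelfExtendQueryPos(600, 512, 4) == 534  (150 + (512 - 128))
//   SelfExtendQueryPos(513, 512, 4) == 512  (128 + 384), right at the boundary

Note that in the diff the grouping only affects the position fed into PositionalEncodingQK; the KV-cache slot (cache_pos) is still computed from the original position.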