Commit 8297cd6

jan-wassenberg authored and copybara-github committed

Also update attention.h to type-erased query_norm_scale

PiperOrigin-RevId: 825004308

1 parent: 3cc0139

File tree: 2 files changed (+4 −4 lines)


gemma/attention.cc

Lines changed: 2 additions & 2 deletions

@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }

 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const MatPtrT<float>& query_norm_scale,
+                           const MatPtr& query_norm_scale,
                            AttentionActivationsPtrs& activations,
                            QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);

gemma/attention.h

Lines changed: 2 additions & 2 deletions

@@ -38,12 +38,12 @@ namespace gcpp {
 void SingleDotSoftmaxWeightedSum( \
     const size_t pos, const size_t start_pos, const size_t last_pos, \
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-    const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+    const MatPtr& query_norm_scale, size_t layer_idx, \
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
 \
 void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
-                           const MatPtrT<float>& query_norm_scale, \
+                           const MatPtr& query_norm_scale, \
                            AttentionActivationsPtrs& activations, \
                            QBatch& qbatch, ThreadingContext& ctx); \
 \
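
The substance of the change: both attention entry points now take the scale as const MatPtr& (a type-erased matrix handle) instead of const MatPtrT<float>& (a handle templated on the element type). Below is a minimal, self-contained sketch of that type-erasure pattern, assuming simplified stand-ins for gemma.cpp's MatPtr/MatPtrT; the Type enum, TypeOf helper, and UseQueryNormScale function are illustrative inventions for this sketch, not the repo's actual definitions.

#include <cstddef>
#include <cstdio>

// Runtime tag standing in for the element type (illustrative subset).
enum class Type { kF32, kBF16 };

template <typename T> constexpr Type TypeOf();
template <> constexpr Type TypeOf<float>() { return Type::kF32; }

// Type-erased matrix handle: pointer, type tag, and extents, but no
// template parameter, so interfaces taking it need not be templated.
class MatPtr {
 public:
  MatPtr(void* ptr, Type type, std::size_t rows, std::size_t cols)
      : ptr_(ptr), type_(type), rows_(rows), cols_(cols) {}
  Type GetType() const { return type_; }
  std::size_t Rows() const { return rows_; }
  std::size_t Cols() const { return cols_; }

 protected:
  void* ptr_;
  Type type_;
  std::size_t rows_, cols_;
};

// Typed view over the same storage; restores T for element access.
template <typename T>
class MatPtrT : public MatPtr {
 public:
  MatPtrT(T* ptr, std::size_t rows, std::size_t cols)
      : MatPtr(ptr, TypeOf<T>(), rows, cols) {}
  const T* Row(std::size_t r) const {
    return static_cast<const T*>(ptr_) + r * cols_;
  }
};

// Analogous to the new signatures above: accepts any element type via the
// erased base and recovers the typed view only where elements are read.
void UseQueryNormScale(const MatPtr& query_norm_scale) {
  if (query_norm_scale.GetType() == Type::kF32) {
    // Safe downcast: the object really is a MatPtrT<float>.
    const auto& typed = static_cast<const MatPtrT<float>&>(query_norm_scale);
    std::printf("scale[0] = %f\n", typed.Row(0)[0]);
  }
}

int main() {
  float scale[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  MatPtrT<float> mat(scale, /*rows=*/1, /*cols=*/4);
  UseQueryNormScale(mat);  // implicit conversion to const MatPtr&
  return 0;
}

The practical effect mirrored by this commit: callers can hold query_norm_scale in whatever precision suits the weights, and the non-templated SingleDotSoftmaxWeightedSum / DotSoftmaxWeightedSum recover the concrete element type at the point of use rather than fixing float in the interface.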
