@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
130130void SingleDotSoftmaxWeightedSum (
131131 const size_t pos, const size_t start_pos, const size_t last_pos,
132132 float * HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
133- const MatPtrT< float > & query_norm_scale, const size_t layer_idx,
133+ const MatPtr & query_norm_scale, const size_t layer_idx,
134134 const AttentionActivationsPtrs& activations, float * HWY_RESTRICT att,
135135 float * HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
136136 const float att_cap = activations.config .att_cap ;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
169169}
170170
171171void DotSoftmaxWeightedSum (const size_t num_tokens, const size_t layer_idx,
172- const MatPtrT< float > & query_norm_scale,
172+ const MatPtr & query_norm_scale,
173173 AttentionActivationsPtrs& activations,
174174 QBatch& qbatch, ThreadingContext& ctx) {
175175 GCPP_ZONE (ctx, 0 , Zones::kGenAttentionDotSoftmaxWeightedSumInclusive );
0 commit comments