Commit 3cc0139

jan-wassenberg authored and copybara-github committed
Fix excessive KC/MC from prior change
This could lead to a stack overflow in B_storage. Also: do not require a specific type for query_norm_scale, update batch sizes for attention tensors, and add more verbose Mat shape/type checks.

PiperOrigin-RevId: 824987689
1 parent 5a05857 · commit 3cc0139
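To make the first sentence of the message concrete: if `B_storage` is a stack array sized by a compile-time cap, any candidate `kc` above that cap writes out of bounds. The sketch below is illustrative only; `kMaxKC`, `kNR`, and `PackB` are assumed stand-ins, not the actual declarations in ops/matmul.cc.

#include <cstddef>

// Assumed stand-ins: the real cap and register-tile constant live in
// ops/matmul.cc and may differ.
constexpr size_t kMaxKC = 512;
constexpr size_t kNR = 4;

// If candidate generation emits kc > kMaxKC, the loop writes past the
// fixed-size stack array: the stack overflow named in the commit message.
void PackB(const float* b, size_t kc) {
  float B_storage[kMaxKC * kNR];  // sized by the compile-time cap
  for (size_t i = 0; i < kc * kNR; ++i) {
    B_storage[i] = b[i];  // out of bounds once kc exceeds kMaxKC
  }
  (void)B_storage;  // real code would consume the packed block here
}

Hence the ops/matmul.cc change below: candidate sizes exceeding `kMaxKC`/`kMaxMC` are now rejected during candidate generation rather than handed to code that assumes they fit.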

6 files changed: +53 −16 lines


gemma/activations.h

Lines changed: 15 additions & 0 deletions
@@ -100,6 +100,8 @@ struct AttentionActivations {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+
+    // `inv_timescale*` are not batched.
   }
 
   MatStorageT<float> q;  // query

@@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
     inv_timescale_global = activations.inv_timescale_global;
   }
 
+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+    // `inv_timescale*` are not batched.
+  }
+
   const ModelConfig& config;
   MatPtrT<float> q;
   MatPtrT<float> q_T;

@@ -203,6 +215,9 @@ struct Activations {
     ffw_out.OverrideRows(batch_size);
 
     attention_storage.SetBatchSize(batch_size);
+    // `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
+    // their row override is not updated when the underlying storage changes.
+    attention.SetBatchSize(batch_size);
   }
 
   const LayerConfig& layer_config;
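The new `AttentionActivationsPtrs::SetBatchSize` exists because `MatPtrT` views snapshot their row counts: overriding rows on the backing storage does not propagate to views created earlier. A minimal stand-alone illustration of the pitfall, using hypothetical `Storage`/`View` types rather than the real classes:

#include <cstddef>
#include <cstdio>

// Hypothetical stand-ins for MatStorageT/MatPtrT; the view caches rows.
struct Storage { size_t rows; };
struct View { size_t rows; };  // copied once; does not track the storage

int main() {
  Storage storage{64};
  View view{storage.rows};  // snapshot of the row count
  storage.rows = 16;        // like attention_storage.SetBatchSize(16)
  // Without an explicit view update (attention.SetBatchSize), the view is
  // stale and reports the old batch size:
  std::printf("storage=%zu view=%zu\n", storage.rows, view.rows);  // 16 vs 64
  view.rows = storage.rows;  // the fix: update the pointer view as well
  return 0;
}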

gemma/flash_attention.cc

Lines changed: 2 additions & 3 deletions
@@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
                                   MatPtrT<float>& q,
-                                  const MatPtrT<float>& query_norm_scale,
+                                  const MatPtr& query_norm_scale,
                                   const size_t layer_idx,
                                   const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {

@@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
-                    const size_t layer_idx,
-                    const MatPtrT<float>& query_norm_scale,
+                    const size_t layer_idx, const MatPtr& query_norm_scale,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
                     ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
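Relaxing `query_norm_scale` from `MatPtrT<float>` to the type-erased base `MatPtr` means callers no longer have to store the scale as `float`; the callee can dispatch on the runtime element type, in the spirit of the `CallUpcasted` call visible in ops/ops-inl.h below. A rough sketch of that dispatch pattern, with simplified `Mat`/`WithUpcast` stand-ins rather than the real gemma.cpp API:

#include <cstdint>
#include <cstdio>

enum class Type { kF32, kBF16 };

// Simplified stand-in for MatPtr: a type-erased view that records its type.
struct Mat {
  Type type;
  const void* data;
};

// Simplified stand-in for CallUpcasted: invokes func with a typed pointer.
template <class Func>
void WithUpcast(const Mat& m, const Func& func) {
  switch (m.type) {
    case Type::kF32:
      func(static_cast<const float*>(m.data));
      break;
    case Type::kBF16:
      func(static_cast<const uint16_t*>(m.data));  // real code upcasts to float
      break;
  }
}

// Accepting the erased type imposes no MatPtrT<float> requirement on callers.
void ApplyScale(const Mat& scale) {
  WithUpcast(scale, [](const auto* p) {
    std::printf("element size %zu\n", sizeof(*p));
  });
}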

gemma/flash_attention.h

Lines changed: 2 additions & 3 deletions
@@ -30,7 +30,7 @@ namespace gcpp {
   namespace NAMESPACE {                                                     \
   void RMSNormAndPositionalEncoding(                                        \
       size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q,           \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx,             \
+      const MatPtr& query_norm_scale, size_t layer_idx,                     \
       const AttentionActivationsPtrs& activations, ThreadingContext& ctx);  \
                                                                             \
   void SingleFlashAttention(size_t start_pos, size_t last_pos,              \

@@ -45,8 +45,7 @@ namespace gcpp {
       size_t total_tasks, size_t target_parallelism);                       \
                                                                             \
   void FlashAttention(size_t num_tokens, size_t target_parallelism,         \
-                      size_t layer_idx,                                     \
-                      const MatPtrT<float>& query_norm_scale,               \
+                      size_t layer_idx, const MatPtr& query_norm_scale,     \
                       AttentionActivationsPtrs& activations, QBatch& qbatch,\
                       ThreadingContext& ctx);                               \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \

ops/matmul.cc

Lines changed: 20 additions & 7 deletions
@@ -175,9 +175,14 @@ class GenerateCandidates {
 
   // The number of A and B columns to read between updating `C`.
   SizeVec KC(size_t mr, MMOrder order) const {
-    // Must return the actual value: although ignored by `RangesOfKC`, this will
-    // be used in MC() and NC().
-    if (IsOneKC(order)) return SizeVec(1, K_);
+    if (IsOneKC(order)) {
+      // A single KC range is infeasible when K exceeds the max. The caller
+      // will skip all configs with `order`.
+      if (K_ > kMaxKC) return SizeVec();
+      // Must return the actual value: although ignored by `RangesOfKC`, this
+      // will be used in MC() and NC().
+      return SizeVec(1, K_);
+    }
     // `LoopKC` handles up to `mr` rows of A.
     const size_t rows_a = HWY_MIN(max_M_, mr);
 

@@ -227,13 +232,21 @@ class GenerateCandidates {
 
   // The number of (L2 resident) A rows for `A2C0` to loop over.
   SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
-    // Must return the actual value: although ignored by `RangesOfMC`, this will
-    // be used in NC().
-    if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_);
+    if (max_M_ <= mr) return SizeVec(1, max_M_);
+    if (IsOneMC(order)) {
+      // A single MC range is infeasible when M exceeds the max. The caller
+      // will skip all configs with `order`.
+      if (max_M_ > kMaxMC) return SizeVec();
+      // Must return the actual value: although ignored by `RangesOfMC`, this
+      // will be used in NC().
+      return SizeVec(1, max_M_);
+    }
 
     // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
     // it is typically inclusive.
     const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
+    // `kc` was chosen to fit in L1, hence this should not exceed L2.
+    HWY_ASSERT(bytes_b <= cache_.L2Bytes());
 
     // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
     // packed B. We want `mc * kc` elements of A to fit in L2, alongside

@@ -242,7 +255,7 @@ class GenerateCandidates {
     size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
     mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
     mc_max = HWY_MIN(mc_max, max_M_);
-    HWY_DASSERT(mc_max != 0);
+    HWY_ASSERT(mc_max != 0);
 
     SizeVec all_mc;
     all_mc.reserve(6);
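Both `KC()` and `MC()` now report an infeasible `order` by returning an empty `SizeVec` instead of a value above the cap. A hedged sketch of how a caller honors that contract (the actual candidate-enumeration loop is outside this diff; `CandidateKC` and `EnumerateConfigs` are invented names):

#include <cstddef>
#include <vector>

using SizeVec = std::vector<size_t>;

constexpr size_t kMaxKC = 512;  // assumed cap, as referenced in the diff

// Mimics the new KC() contract: an empty result marks the order infeasible.
SizeVec CandidateKC(size_t k, bool is_one_kc) {
  if (is_one_kc) {
    if (k > kMaxKC) return SizeVec();  // caller skips this order entirely
    return SizeVec(1, k);
  }
  return SizeVec{128, 256, 512};  // otherwise enumerate capped candidates
}

void EnumerateConfigs(size_t k) {
  for (bool is_one_kc : {true, false}) {
    const SizeVec all_kc = CandidateKC(k, is_one_kc);
    if (all_kc.empty()) continue;  // no feasible kc: drop this order
    // ... combine each kc with MC/NC candidates ...
  }
}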

ops/ops-inl.h

Lines changed: 1 addition & 1 deletion
@@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
                     size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
-  HWY_DASSERT(activations.SameShape(out));
+  activations.DebugCheckSameShape(out);
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
     ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,

util/mat.h

Lines changed: 13 additions & 2 deletions
@@ -181,7 +181,15 @@ class MatPtr : public IFields {
   Extents2D Extents() const { return Extents2D(Rows(), cols_); }
   bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
   bool SameShape(const MatPtr& other) const {
-    return Rows() == other.Rows() && cols_ == other.cols_;
+    return Rows() == other.Rows() && Cols() == other.Cols();
+  }
+  void DebugCheckSameShape(const MatPtr& other) const {
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (!SameShape(other)) {
+        HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
+                  Rows(), Cols(), other.Rows(), other.Cols());
+      }
+    }
   }
   // Future calls to `Rows()` during this class' lifetime (not serialized)
   // will return this value. Used to set the actual number of rows for

@@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
     if (GetType() == Type::kUnknown) {
       SetType(TypeEnum<MatT>());
     } else {
-      HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
+      if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
+        HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
+                  TypeName<MatT>(), TypeName(other.GetType()));
+      }
     }
   }
   MatPtrT& operator=(const MatPtr& other) {
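Compared with a bare `HWY_DASSERT(SameShape(...))`, the new `DebugCheckSameShape` aborts with the tensor name and both shapes, which is what the commit message calls "more verbose Mat shape/type checks". A self-contained approximation without the Highway macros (assumed `Shape` type; `NDEBUG` stands in for `HWY_IS_DEBUG_BUILD`):

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <string>

#ifndef NDEBUG
#define IS_DEBUG_BUILD 1
#else
#define IS_DEBUG_BUILD 0
#endif

struct Shape {
  std::string name;
  size_t rows, cols;

  bool SameShape(const Shape& other) const {
    return rows == other.rows && cols == other.cols;
  }
  // Debug-only check; on mismatch it names the tensor and prints both shapes
  // instead of just "assertion failed".
  void DebugCheckSameShape(const Shape& other) const {
    if constexpr (IS_DEBUG_BUILD) {
      if (!SameShape(other)) {
        std::fprintf(stderr, "%s: shape mismatch %zu x %zu vs %zu x %zu\n",
                     name.c_str(), rows, cols, other.rows, other.cols);
        std::abort();
      }
    }
  }
};

int main() {
  Shape a{"activations", 4, 8}, b{"out", 4, 16};
  a.DebugCheckSameShape(b);  // debug builds: "activations: shape mismatch ..."
  return 0;
}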
