From 3cc0139ebbc3b3619cf427ab75047a80be67f42d Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 28 Oct 2025 05:32:30 -0700
Subject: [PATCH] Fix excessive KC/MC from prior change

The excessive KC/MC could lead to stack overflow in B_storage.

Also: do not require a specific type for query_norm_scale, update
batch sizes for attention tensors, and make Mat shape/type checks
more verbose.

PiperOrigin-RevId: 824987689
---
 gemma/activations.h      | 15 +++++++++++++++
 gemma/flash_attention.cc |  5 ++---
 gemma/flash_attention.h  |  5 ++---
 ops/matmul.cc            | 27 ++++++++++++++++++++-------
 ops/ops-inl.h            |  2 +-
 util/mat.h               | 15 +++++++++++++--
 6 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/gemma/activations.h b/gemma/activations.h
index 40320d8e..acaecb79 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -100,6 +100,8 @@ struct AttentionActivations {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+
+    // `inv_timescale*` are not batched.
   }
 
   MatStorageT q;  // query
@@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
     inv_timescale_global = activations.inv_timescale_global;
   }
 
+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+    // `inv_timescale*` are not batched.
+  }
+
   const ModelConfig& config;
   MatPtrT q;
   MatPtrT q_T;
@@ -203,6 +215,9 @@ struct Activations {
     ffw_out.OverrideRows(batch_size);
 
     attention_storage.SetBatchSize(batch_size);
+    // `AttentionActivationsPtrs` holds `MatPtrT` members that also require
+    // updating; their row override does not track the underlying storage.
+    attention.SetBatchSize(batch_size);
   }
 
   const LayerConfig& layer_config;
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index b5dd2418..49cdfdc3 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t,
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
                                   MatPtrT& q,
-                                  const MatPtrT& query_norm_scale,
+                                  const MatPtr& query_norm_scale,
                                   const size_t layer_idx,
                                   const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {
@@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism, - const size_t layer_idx, - const MatPtrT& query_norm_scale, + const size_t layer_idx, const MatPtr& query_norm_scale, AttentionActivationsPtrs& activations, QBatch& qbatch, ThreadingContext& ctx) { GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 89af4984..ab3a3952 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -30,7 +30,7 @@ namespace gcpp { namespace NAMESPACE { \ void RMSNormAndPositionalEncoding( \ size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ - const MatPtrT& query_norm_scale, size_t layer_idx, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ \ void SingleFlashAttention(size_t start_pos, size_t last_pos, \ @@ -45,8 +45,7 @@ namespace gcpp { size_t total_tasks, size_t target_parallelism); \ \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, \ - const MatPtrT& query_norm_scale, \ + size_t layer_idx, const MatPtr& query_norm_scale, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \ ThreadingContext& ctx); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ diff --git a/ops/matmul.cc b/ops/matmul.cc index 5a9fb0d2..c01943d1 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -175,9 +175,14 @@ class GenerateCandidates { // The number of A and B columns to read between updating `C`. SizeVec KC(size_t mr, MMOrder order) const { - // Must return the actual value: although ignored by `RangesOfKC`, this will - // be used in MC() and NC(). - if (IsOneKC(order)) return SizeVec(1, K_); + if (IsOneKC(order)) { + // A single KC range is infeasible when K exceeds the max. The caller + // will skip all configs with `order`. + if (K_ > kMaxKC) return SizeVec(); + // Must return the actual value: although ignored by `RangesOfKC`, this + // will be used in MC() and NC(). + return SizeVec(1, K_); + } // `LoopKC` handles up to `mr` rows of A. const size_t rows_a = HWY_MIN(max_M_, mr); @@ -227,13 +232,21 @@ class GenerateCandidates { // The number of (L2 resident) A rows for `A2C0` to loop over. SizeVec MC(size_t mr, size_t kc, MMOrder order) const { - // Must return the actual value: although ignored by `RangesOfMC`, this will - // be used in NC(). - if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_); + if (max_M_ <= mr) return SizeVec(1, max_M_); + if (IsOneMC(order)) { + // A single MC range is infeasible when M exceeds the max. The caller + // will skip all configs with `order`. + if (max_M_ > kMaxMC) return SizeVec(); + // Must return the actual value: although ignored by `RangesOfMC`, this + // will be used in NC(). + return SizeVec(1, max_M_); + } // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because // it is typically inclusive. const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16)); + // `kc` was chosen to fit in L1, hence this should not exceed L2. + HWY_ASSERT(bytes_b <= cache_.L2Bytes()); // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // packed B. 
We want `mc * kc` elements of A to fit in L2, alongside
@@ -242,7 +255,7 @@ class GenerateCandidates {
     size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
     mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
     mc_max = HWY_MIN(mc_max, max_M_);
-    HWY_DASSERT(mc_max != 0);
+    HWY_ASSERT(mc_max != 0);
 
     SizeVec all_mc;
     all_mc.reserve(6);
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index f2933cac..80183bba 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights,
                     size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
-  HWY_DASSERT(activations.SameShape(out));
+  activations.DebugCheckSameShape(out);
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
     ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
diff --git a/util/mat.h b/util/mat.h
index 6b8dc064..753e1941 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -181,7 +181,15 @@ class MatPtr : public IFields {
   Extents2D Extents() const { return Extents2D(Rows(), cols_); }
   bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
   bool SameShape(const MatPtr& other) const {
-    return Rows() == other.Rows() && cols_ == other.cols_;
+    return Rows() == other.Rows() && Cols() == other.Cols();
+  }
+  void DebugCheckSameShape(const MatPtr& other) const {
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (!SameShape(other)) {
+        HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
+                  Rows(), Cols(), other.Rows(), other.Cols());
+      }
+    }
   }
   // Future calls to `Rows()` during this class' lifetime (not serialized)
   // will return this value. Used to set the actual number of rows for
@@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
     if (GetType() == Type::kUnknown) {
       SetType(TypeEnum());
     } else {
-      HWY_ASSERT(other.GetType() == TypeEnum());
+      if (HWY_UNLIKELY(other.GetType() != TypeEnum())) {
+        HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
+                  TypeName(), TypeName(other.GetType()));
+      }
    }
   }
   MatPtrT& operator=(const MatPtr& other) {
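
Note on the gemma/activations.h change above: `AttentionActivationsPtrs` caches its own row
overrides, so calling `attention_storage.SetBatchSize()` alone leaves the pointer views
reporting the old batch size. The standalone sketch below (not part of the patch; `DemoStorage`
and `DemoView` are hypothetical stand-ins for `MatStorageT` / `MatPtrT`) illustrates the
stale-rows pitfall that the added `attention.SetBatchSize(batch_size)` call avoids.

#include <cstddef>
#include <cstdio>

// Hypothetical stand-ins: each object caches its own row count, mirroring how
// OverrideRows() only affects the object it is called on.
class DemoStorage {
 public:
  explicit DemoStorage(size_t rows) : rows_(rows) {}
  void OverrideRows(size_t rows) { rows_ = rows; }
  size_t Rows() const { return rows_; }

 private:
  size_t rows_;
};

class DemoView {
 public:
  explicit DemoView(const DemoStorage& storage) : rows_(storage.Rows()) {}
  // Must be called separately; the view does not observe the storage.
  void OverrideRows(size_t rows) { rows_ = rows; }
  size_t Rows() const { return rows_; }

 private:
  size_t rows_;  // snapshot taken at construction time
};

int main() {
  DemoStorage storage(/*rows=*/64);  // plays the role of AttentionActivations
  DemoView view(storage);            // plays the role of AttentionActivationsPtrs

  storage.OverrideRows(16);  // like calling attention_storage.SetBatchSize(16) alone
  std::printf("storage rows %zu, view rows %zu (stale)\n", storage.Rows(),
              view.Rows());

  view.OverrideRows(16);  // like the added attention.SetBatchSize(16)
  std::printf("view rows %zu (now consistent)\n", view.Rows());
  return 0;
}

Keeping the override on each object is presumably what makes `OverrideRows` cheap (no
reallocation), at the cost of having to propagate batch-size changes to every view, which is
what `Activations::SetBatchSize` now does.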