Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions gemma/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ struct AttentionActivations {
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);

// `inv_timescale*` are not batched.
}

MatStorageT<float> q; // query
Expand Down Expand Up @@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
inv_timescale_global = activations.inv_timescale_global;
}

// Caps the logical row count of the batch-major attention views so that
// subsequent ops only process `batch_size` rows. These `MatPtrT` views carry
// their own row override, independent of the backing storage, so this must be
// called whenever the storage's batch size changes.
// NOTE(review): assumes `batch_size` does not exceed the allocated rows —
// confirm `OverrideRows` validates this.
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
// `inv_timescale*` are not batched.
}

const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<float> q_T;
Expand Down Expand Up @@ -203,6 +215,9 @@ struct Activations {
ffw_out.OverrideRows(batch_size);

attention_storage.SetBatchSize(batch_size);
// `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
// their row override is not updated when the underlying storage changes.
attention.SetBatchSize(batch_size);
}

const LayerConfig& layer_config;
Expand Down
5 changes: 2 additions & 3 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
// Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<float>& q,
const MatPtrT<float>& query_norm_scale,
const MatPtr& query_norm_scale,
const size_t layer_idx,
const AttentionActivationsPtrs& activations,
ThreadingContext& ctx) {
Expand Down Expand Up @@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
// grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx,
const MatPtrT<float>& query_norm_scale,
const size_t layer_idx, const MatPtr& query_norm_scale,
AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
Expand Down
5 changes: 2 additions & 3 deletions gemma/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace gcpp {
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
Expand All @@ -45,8 +45,7 @@ namespace gcpp {
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, \
const MatPtrT<float>& query_norm_scale, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
Expand Down
27 changes: 20 additions & 7 deletions ops/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,14 @@ class GenerateCandidates {

// The number of A and B columns to read between updating `C`.
SizeVec KC(size_t mr, MMOrder order) const {
// Must return the actual value: although ignored by `RangesOfKC`, this will
// be used in MC() and NC().
if (IsOneKC(order)) return SizeVec(1, K_);
if (IsOneKC(order)) {
// A single KC range is infeasible when K exceeds the max. The caller
// will skip all configs with `order`.
if (K_ > kMaxKC) return SizeVec();
// Must return the actual value: although ignored by `RangesOfKC`, this
// will be used in MC() and NC().
return SizeVec(1, K_);
}
// `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(max_M_, mr);

Expand Down Expand Up @@ -227,13 +232,21 @@ class GenerateCandidates {

// The number of (L2 resident) A rows for `A2C0` to loop over.
SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
// Must return the actual value: although ignored by `RangesOfMC`, this will
// be used in NC().
if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_);
if (max_M_ <= mr) return SizeVec(1, max_M_);
if (IsOneMC(order)) {
// A single MC range is infeasible when M exceeds the max. The caller
// will skip all configs with `order`.
if (max_M_ > kMaxMC) return SizeVec();
// Must return the actual value: although ignored by `RangesOfMC`, this
// will be used in NC().
return SizeVec(1, max_M_);
}

// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
// it is typically inclusive.
const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
// `kc` was chosen to fit in L1, hence this should not exceed L2.
HWY_ASSERT(bytes_b <= cache_.L2Bytes());

// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
Expand All @@ -242,7 +255,7 @@ class GenerateCandidates {
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
mc_max = HWY_MIN(mc_max, max_M_);
HWY_DASSERT(mc_max != 0);
HWY_ASSERT(mc_max != 0);

SizeVec all_mc;
all_mc.reserve(6);
Expand Down
2 changes: 1 addition & 1 deletion ops/ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
size_t cluster_idx = 0) {
HWY_DASSERT(weights.Rows() == 1);
HWY_DASSERT(weights.Cols() == activations.Cols());
HWY_DASSERT(activations.SameShape(out));
activations.DebugCheckSameShape(out);

CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
Expand Down
15 changes: 13 additions & 2 deletions util/mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,15 @@ class MatPtr : public IFields {
Extents2D Extents() const { return Extents2D(Rows(), cols_); }
bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
// Returns true iff both matrices have the same logical (rows, cols) extents.
// Goes through the accessors (not the raw `cols_` member) so that row
// overrides set via `OverrideRows` are respected on both operands.
bool SameShape(const MatPtr& other) const {
  return Rows() == other.Rows() && Cols() == other.Cols();
}
void DebugCheckSameShape(const MatPtr& other) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (!SameShape(other)) {
HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
Rows(), Cols(), other.Rows(), Cols());
}
}
}
// Future calls to `Rows()` during this class' lifetime (not serialized)
// will return this value. Used to set the actual number of rows for
Expand Down Expand Up @@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
if (GetType() == Type::kUnknown) {
SetType(TypeEnum<MatT>());
} else {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
TypeName<MatT>(), TypeName(other.GetType()));
}
}
}
MatPtrT& operator=(const MatPtr& other) {
Expand Down
Loading