Skip to content

Commit 5cb8d67

Browse files
jan-wassenberg authored and copybara-github committed
Add int8 quantization stats
Compute the L1 error and Shannon SNR (higher is better). PiperOrigin-RevId: 846699792
1 parent 11aa16a commit 5cb8d67

File tree

2 files changed

+157
-57
lines changed

2 files changed

+157
-57
lines changed

gemma/tensor_stats.cc

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,98 @@ void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
107107
bytes_per_row * row_idx);
108108
}
109109

110+
constexpr size_t kGroupSize = 128;  // subchannel

// Simulates int8 quantization of one group of `kGroupSize` floats and records
// the resulting distortion (mean L1 error and a Shannon-style SNR, higher is
// better) into `my_stats` via `NotifyGroup`. Groups that would divide by zero
// (constant input) or that are losslessly representable are skipped.
void QuantizeGroup(const float* HWY_RESTRICT in,
                   TensorStatsAccumulator& my_stats) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
  using MF = hn::Mask<decltype(df)>;
  const hn::ScalableTag<double> dd;
  using VD = hn::Vec<decltype(dd)>;
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  // NOTE(review): `enc`, `dec` and `all_snr` are written below but not read
  // again in this function — presumably kept for debugging/inspection; confirm.
  HWY_ALIGN float enc[kGroupSize];
  HWY_ALIGN float dec[kGroupSize];
  HWY_ALIGN float all_snr[kGroupSize];
  HWY_DASSERT(kGroupSize % NF == 0);  // No remainder handling required.

  const VF k0 = hn::Zero(df);
  const VF k1 = hn::Set(df, 1.0f);

  // Scan for min/max for quantization.
  VF vmin = hn::Set(df, hwy::HighestValue<float>());
  VF vmax = hn::Set(df, hwy::LowestValue<float>());
  for (size_t i = 0; i < kGroupSize; i += NF) {
    const VF v = hn::Load(df, in + i);
    vmin = hn::Min(vmin, v);
    vmax = hn::Max(vmax, v);
  }
  const float min = hn::ReduceMin(df, vmin);
  const float max = hn::ReduceMax(df, vmax);
  // Avoid division by zero during quantization.
  if (max == min) return;

  // Distortion stats. Two double accumulators (lower/upper float halves)
  // because `snr` is promoted to double before taking the log.
  VF vsum_err = hn::Zero(df);
  VD sum_log_snr0 = hn::Zero(dd);
  VD sum_log_snr1 = hn::Zero(dd);
  size_t num_snr = 0;  // Number of elements with nonzero error.

  // Unclipped asymmetric quantization (for activations): maps [min, max] to
  // 255 steps; `zeropoint` re-centers so codes land near [-128, 127].
  const VF scale = hn::Set(df, 255.0f / (max - min));
  const VF inv_scale = hn::Div(k1, scale);
  const VF zeropoint = hn::Sub(hn::Round(hn::Mul(hn::Set(df, -min), scale)),
                               hn::Set(df, 128.0f));
  const VF dq_sub = hn::Mul(zeropoint, inv_scale);  // For MulSub.
  for (size_t i = 0; i < kGroupSize; i += NF) {
    const VF v = hn::Load(df, in + i);
    const VF q = hn::Round(hn::MulAdd(v, scale, zeropoint));
    hn::Store(q, df, enc + i);
    // Dequantize: d = (q - zeropoint) / scale, via one MulSub.
    const VF d = hn::MulSub(q, inv_scale, dq_sub);
    hn::Store(d, df, dec + i);

    const VF err = hn::AbsDiff(v, d);  // L1
    vsum_err = hn::Add(vsum_err, err);

    // For preventing division by zero. However, we still want to
    // clamp snr because it could be very high (>1E3 when most
    // elements are lossless).
    const MF has_err = hn::Gt(err, k0);
    const VF rel = hn::MaskedDivOr(k0, has_err, hn::Abs(v), err);
    // SNR = 1 + abs/L1, with cap on the latter term.
    const VF snr = hn::Add(k1, hn::Min(rel, hn::Set(df, 300.f)));
    hn::Store(snr, df, all_snr + i);
    // Where `has_err` is false, `snr` elements are 1 and log(1) is zero, hence
    // they do not affect sum_log. However, very high errors also result in
    // snr=1, which drags down the average because `sum_log` is increased.
    num_snr += hn::CountTrue(df, has_err);

    const VD log_snr0 = hn::Log(dd, hn::PromoteLowerTo(dd, snr));
    const VD log_snr1 = hn::Log(dd, hn::PromoteUpperTo(dd, snr));
    sum_log_snr0 = hn::Add(sum_log_snr0, log_snr0);
    sum_log_snr1 = hn::Add(sum_log_snr1, log_snr1);
  }

  const float sum_err = hn::ReduceSum(df, vsum_err);
  const float avg_L1 = sum_err / static_cast<float>(kGroupSize);
  const double sum_log = hn::ReduceSum(dd, hn::Add(sum_log_snr0, sum_log_snr1));
  // SNR >= 1, hence log >= 0.
  HWY_ASSERT(sum_log >= 0.0);
  if (num_snr == 0) {  // Avoid division by zero.
    // It can happen that dequantization is lossless, i.e. SNR is
    // infinite; skip such groups.
    HWY_ASSERT(sum_err == 0.0f);
    return;
  }
  // Signal to noise ratio (Shannon's channel capacity, NOT the
  // L2-based and logarithmic PSNR); geometric mean over elements with error.
  const float snr = std::exp(sum_log / static_cast<double>(num_snr));

  my_stats.NotifyGroup(avg_L1, snr);
}
201+
110202
// First dispatch to the type, then parallel over rows, then vectorized
111203
// decompress and Notify for each value.
112204
void UpdateStatsT(TensorStats& stats, size_t layer_idx,
@@ -138,37 +230,38 @@ void UpdateStatsT(TensorStats& stats, size_t layer_idx,
138230
my_stats.NotifyCond(ConditionNumber(row, cols));
139231

140232
namespace hn = hwy::HWY_NAMESPACE;
141-
hn::ScalableTag<float> df;
233+
const hn::ScalableTag<float> df;
142234
using VF = hn::Vec<decltype(df)>;
143235
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
144-
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
236+
HWY_ALIGN float buf[kGroupSize];
237+
size_t buf_filled = 0;
145238

146239
size_t packed_ofs = 0;
147240
if (cols >= 2 * NF) {
148241
for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
149242
VF v0, v1;
150243
Decompress2(df, packed, packed_ofs, v0, v1);
151-
hn::Store(v0, df, buf);
152-
hn::Store(v1, df, buf + NF);
153-
const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
154-
const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
155-
const float min = hn::ReduceMin(df, min_mag);
156-
if (min != 0.0f) { // Avoid division by zero.
157-
my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
158-
}
244+
hn::Store(v0, df, buf + buf_filled);
245+
hn::Store(v1, df, buf + buf_filled + NF);
246+
buf_filled += 2 * NF;
247+
if (buf_filled == kGroupSize) {
248+
QuantizeGroup(buf, my_stats);
249+
250+
for (size_t i = 0; i < kGroupSize; ++i) {
251+
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
252+
}
253+
my_stats.NotifyCorr(Correlation(buf, kGroupSize));
159254

160-
for (size_t i = 0; i < 2 * NF; ++i) {
161-
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
255+
buf_filled = 0;
162256
}
163-
my_stats.NotifyCorr(Correlation(buf, 2 * NF));
164257
}
165258
}
166259

167260
// Zero to two vectors remaining.
168261
for (; packed_ofs < cols; packed_ofs += NF) {
169262
const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
170263
DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
171-
// Skip NotifyGroup for this partial group.
264+
// Skip QuantizeGroup because it requires full groups.
172265
for (size_t i = 0; i < remaining; ++i) {
173266
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
174267
}

gemma/tensor_stats.h

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ struct TensorStatsAcrossLayers {
6868
fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
6969
}
7070
fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
71-
fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
71+
fprintf(stderr, "err_avg %s\n", s_grp_err_avg.ToString(skip).c_str());
72+
fprintf(stderr, "err_std %s\n", s_grp_err_std.ToString(skip).c_str());
73+
fprintf(stderr, "err_max %s\n", s_grp_err_max.ToString(skip).c_str());
74+
fprintf(stderr, "snr_1 %s\n", s_grp_snr1.ToString(skip).c_str());
75+
fprintf(stderr, "snr_avg %s\n", s_grp_snr_avg.ToString(skip).c_str());
76+
fprintf(stderr, "snr_std %s\n", s_grp_snr_std.ToString(skip).c_str());
7277
fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
7378
fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
7479
fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
@@ -112,7 +117,12 @@ struct TensorStatsAcrossLayers {
112117
hwy::Stats s_corr_avg;
113118
hwy::Stats s_corr_max;
114119

115-
hwy::Stats s_range_avg;
120+
hwy::Stats s_grp_err_avg;
121+
hwy::Stats s_grp_err_std;
122+
hwy::Stats s_grp_err_max;
123+
hwy::Stats s_grp_snr1;
124+
hwy::Stats s_grp_snr_avg;
125+
hwy::Stats s_grp_snr_std;
116126

117127
hwy::Stats s_exp_min;
118128
hwy::Stats s_exp_max;
@@ -151,13 +161,11 @@ class TensorStatsAccumulator {
151161
void DoNotPrint() { skip_.fetch_or(1); }
152162
bool ShouldPrint() const { return skip_.load() == 0; }
153163

154-
// Vector code computed the min/max of a group (= two vectors); this is
155-
// faster than doing it in `Notify`.
156-
void NotifyGroup(float min, float max) {
157-
s_group_min_.Notify(min);
158-
s_group_max_.Notify(max);
159-
// Caller ensures min != 0.
160-
s_group_range_.Notify(max / min);
164+
// Computed by vector code, much faster than doing it in `Notify`.
165+
void NotifyGroup(float avg_L1, float snr) {
166+
s_group_err_.Notify(avg_L1);
167+
s_group_snr_.Notify(snr);
168+
num_snr1_ += (snr == 1.0f);
161169
}
162170

163171
void NotifyCorr(float corr) { s_corr_.Notify(corr); }
@@ -173,9 +181,9 @@ class TensorStatsAccumulator {
173181
s_val_.Assimilate(other.s_val_);
174182
s_mag_.Assimilate(other.s_mag_);
175183
s_corr_.Assimilate(other.s_corr_);
176-
s_group_min_.Assimilate(other.s_group_min_);
177-
s_group_max_.Assimilate(other.s_group_max_);
178-
s_group_range_.Assimilate(other.s_group_range_);
184+
s_group_err_.Assimilate(other.s_group_err_);
185+
s_group_snr_.Assimilate(other.s_group_snr_);
186+
num_snr1_ += other.num_snr1_;
179187
}
180188

181189
// Called on the per-layer representative after reducing across threads.
@@ -197,7 +205,12 @@ class TensorStatsAccumulator {
197205
s.s_corr_avg.Notify(s_corr_.Mean());
198206
s.s_corr_max.Notify(s_corr_.Max());
199207

200-
s.s_range_avg.Notify(s_group_range_.Mean());
208+
s.s_grp_err_avg.Notify(s_group_err_.Mean());
209+
s.s_grp_err_std.Notify(s_group_err_.StandardDeviation());
210+
s.s_grp_err_max.Notify(s_group_err_.Max());
211+
s.s_grp_snr1.Notify(static_cast<float>(num_snr1_));
212+
s.s_grp_snr_avg.Notify(s_group_snr_.Mean());
213+
s.s_grp_snr_std.Notify(s_group_snr_.StandardDeviation());
201214

202215
const uint32_t subnormals = b_exp256_.Bin(0);
203216
// Prevent subnormals from hiding the min exponent.
@@ -222,13 +235,12 @@ class TensorStatsAccumulator {
222235
void PrintAll() {
223236
fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
224237
const int skip = hwy::Stats::kNoGeomean;
225-
fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
226-
fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
227-
fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
228-
fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str());
229-
fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str());
230-
fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str());
231-
fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
238+
fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
239+
fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
240+
fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
241+
fprintf(stderr, "crr %s\n", s_corr_.ToString(skip).c_str());
242+
fprintf(stderr, "err %s\n", s_group_err_.ToString(skip).c_str());
243+
fprintf(stderr, "snr %s\n", s_group_snr_.ToString(skip).c_str());
232244
b_exp256_.Print("exp");
233245
PrintBinRanges(b_big_row_, "big row");
234246
PrintBinRanges(b_big_col_, "big col");
@@ -244,30 +256,25 @@ class TensorStatsAccumulator {
244256
}
245257
if (total == 0) return;
246258

247-
// If all bins are at least 10% of a uniform distribution, print the range
248-
// to vastly reduce the log size.
259+
fprintf(stderr, "%s total %zu: \n", name, total);
260+
// Group together runs to reduce the log size.
249261
const size_t min = HWY_MAX(1, total / (N * 10));
250-
size_t last = 0;
251-
for (; last < N; ++last) {
252-
if (b.Bin(last) < min) break;
253-
}
254-
if (last >= N / 2) {
255-
// Also require all subsequent bins to be zero, otherwise we should
256-
// print the outlier bins.
257-
bool all_zero = true;
258-
for (size_t i = last + 1; i < N; ++i) {
259-
if (b.Bin(last) != 0) {
260-
all_zero = false;
261-
break;
262-
}
262+
for (size_t i = 0; i < N; ++i) {
263+
if (b.Bin(i) == 0) continue;
264+
if (b.Bin(i) < min) {
265+
fprintf(stderr, " %3zu: %zu\n", i, b.Bin(i));
266+
continue;
263267
}
264-
if (all_zero) {
265-
fprintf(stderr, "%s: uniform up to %zu\n", name, last);
266-
return;
268+
const size_t first = i;
269+
while (i + 1 < N && b.Bin(i + 1) >= min) {
270+
i++;
271+
}
272+
if (first == i) {
273+
fprintf(stderr, " %3zu: %zu\n", i, b.Bin(i));
274+
} else {
275+
fprintf(stderr, " [%3zu, %3zu]\n", first, i);
267276
}
268277
}
269-
270-
b.Print(name, /*skip_zero=*/true);
271278
}
272279

273280
double sum_sq_ = 0.0; // for Frobenius norm
@@ -278,9 +285,9 @@ class TensorStatsAccumulator {
278285
hwy::Stats s_mag_;
279286
hwy::Stats s_cond_; // condition number
280287
hwy::Stats s_corr_; // lag-1 autocorrelation
281-
hwy::Stats s_group_min_;
282-
hwy::Stats s_group_max_;
283-
hwy::Stats s_group_range_;
288+
hwy::Stats s_group_err_;
289+
hwy::Stats s_group_snr_;
290+
size_t num_snr1_ = 0;
284291
std::atomic<int> skip_{0};
285292
};
286293

0 commit comments

Comments (0)