@@ -107,6 +107,98 @@ void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
107107 bytes_per_row * row_idx);
108108}
109109
110+ constexpr size_t kGroupSize = 128 ; // subchannel
111+
112+ void QuantizeGroup (const float * HWY_RESTRICT in,
113+ TensorStatsAccumulator& my_stats) {
114+ namespace hn = hwy::HWY_NAMESPACE;
115+ const hn::ScalableTag<float > df;
116+ using VF = hn::Vec<decltype (df)>;
117+ using MF = hn::Mask<decltype (df)>;
118+ const hn::ScalableTag<double > dd;
119+ using VD = hn::Vec<decltype (dd)>;
120+ HWY_LANES_CONSTEXPR size_t NF = hn::Lanes (df);
121+ HWY_ALIGN float enc[kGroupSize ];
122+ HWY_ALIGN float dec[kGroupSize ];
123+ HWY_ALIGN float all_snr[kGroupSize ];
124+ HWY_DASSERT (kGroupSize % NF == 0 ); // No remainder handling required.
125+
126+ const VF k0 = hn::Zero (df);
127+ const VF k1 = hn::Set (df, 1 .0f );
128+
129+ // Scan for min/max for quantization.
130+ VF vmin = hn::Set (df, hwy::HighestValue<float >());
131+ VF vmax = hn::Set (df, hwy::LowestValue<float >());
132+ for (size_t i = 0 ; i < kGroupSize ; i += NF) {
133+ const VF v = hn::Load (df, in + i);
134+ vmin = hn::Min (vmin, v);
135+ vmax = hn::Max (vmax, v);
136+ }
137+ const float min = hn::ReduceMin (df, vmin);
138+ const float max = hn::ReduceMax (df, vmax);
139+ // Avoid division by zero during quantization.
140+ if (max == min) return ;
141+
142+ // Distortion stats.
143+ VF vsum_err = hn::Zero (df);
144+ VD sum_log_snr0 = hn::Zero (dd);
145+ VD sum_log_snr1 = hn::Zero (dd);
146+ size_t num_snr = 0 ;
147+
148+ // Unclipped asymmetric quantization (for activations).
149+ const VF scale = hn::Set (df, 255 .0f / (max - min));
150+ const VF inv_scale = hn::Div (k1, scale);
151+ const VF zeropoint = hn::Sub (hn::Round (hn::Mul (hn::Set (df, -min), scale)),
152+ hn::Set (df, 128 .0f ));
153+ const VF dq_sub = hn::Mul (zeropoint, inv_scale); // For MulSub.
154+ for (size_t i = 0 ; i < kGroupSize ; i += NF) {
155+ const VF v = hn::Load (df, in + i);
156+ const VF q = hn::Round (hn::MulAdd (v, scale, zeropoint));
157+ hn::Store (q, df, enc + i);
158+ // Dequantize.
159+ const VF d = hn::MulSub (q, inv_scale, dq_sub);
160+ hn::Store (d, df, dec + i);
161+
162+ const VF err = hn::AbsDiff (v, d); // L1
163+ vsum_err = hn::Add (vsum_err, err);
164+
165+ // For preventing division by zero. However, we still want to
166+ // clamp snr because it could be very high (>1E3 when most
167+ // elements are lossless).
168+ const MF has_err = hn::Gt (err, k0);
169+ const VF rel = hn::MaskedDivOr (k0, has_err, hn::Abs (v), err);
170+ // SNR = 1 + abs/L1, with cap on the latter term.
171+ const VF snr = hn::Add (k1, hn::Min (rel, hn::Set (df, 300 .f )));
172+ hn::Store (snr, df, all_snr + i);
173+ // Where `has_err` is false, `snr` elements are 1 and log(1) is zero, hence
174+ // they do not affect sum_log. However, very high errors also result in
175+ // snr=1, which drags down the average because `sum_log` is increased.
176+ num_snr += hn::CountTrue (df, has_err);
177+
178+ const VD log_snr0 = hn::Log (dd, hn::PromoteLowerTo (dd, snr));
179+ const VD log_snr1 = hn::Log (dd, hn::PromoteUpperTo (dd, snr));
180+ sum_log_snr0 = hn::Add (sum_log_snr0, log_snr0);
181+ sum_log_snr1 = hn::Add (sum_log_snr1, log_snr1);
182+ }
183+
184+ const float sum_err = hn::ReduceSum (df, vsum_err);
185+ const float avg_L1 = sum_err / static_cast <float >(kGroupSize );
186+ const double sum_log = hn::ReduceSum (dd, hn::Add (sum_log_snr0, sum_log_snr1));
187+ // SNR >= 1, hence log >= 0.
188+ HWY_ASSERT (sum_log >= 0.0 );
189+ if (num_snr == 0 ) { // Avoid division by zero.
190+ // It can happen that dequantization is lossless, i.e. SNR is
191+ // infinite; skip such groups.
192+ HWY_ASSERT (sum_err == 0 .0f );
193+ return ;
194+ }
195+ // Signal to noise ratio (Shannon's channel capacity, NOT the
196+ // L2-based and logarithmic PSNR)
197+ const float snr = std::exp (sum_log / static_cast <double >(num_snr));
198+
199+ my_stats.NotifyGroup (avg_L1, snr);
200+ }
201+
110202// First dispatch to the type, then parallel over rows, then vectorized
111203// decompress and Notify for each value.
112204void UpdateStatsT (TensorStats& stats, size_t layer_idx,
@@ -138,37 +230,38 @@ void UpdateStatsT(TensorStats& stats, size_t layer_idx,
138230 my_stats.NotifyCond (ConditionNumber (row, cols));
139231
140232 namespace hn = hwy::HWY_NAMESPACE;
141- hn::ScalableTag<float > df;
233+ const hn::ScalableTag<float > df;
142234 using VF = hn::Vec<decltype (df)>;
143235 HWY_LANES_CONSTEXPR size_t NF = hn::Lanes (df);
144- HWY_ALIGN float buf[2 * hn::MaxLanes (df)];
236+ HWY_ALIGN float buf[kGroupSize ];
237+ size_t buf_filled = 0 ;
145238
146239 size_t packed_ofs = 0 ;
147240 if (cols >= 2 * NF) {
148241 for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
149242 VF v0, v1;
150243 Decompress2 (df, packed, packed_ofs, v0, v1);
151- hn::Store (v0, df, buf);
152- hn::Store (v1, df, buf + NF);
153- const VF min_mag = hn::Min (hn::Abs (v0), hn::Abs (v1));
154- const VF max_mag = hn::Max (hn::Abs (v0), hn::Abs (v1));
155- const float min = hn::ReduceMin (df, min_mag);
156- if (min != 0 .0f ) { // Avoid division by zero.
157- my_stats.NotifyGroup (min, hn::ReduceMax (df, max_mag));
158- }
244+ hn::Store (v0, df, buf + buf_filled);
245+ hn::Store (v1, df, buf + buf_filled + NF);
246+ buf_filled += 2 * NF;
247+ if (buf_filled == kGroupSize ) {
248+ QuantizeGroup (buf, my_stats);
249+
250+ for (size_t i = 0 ; i < kGroupSize ; ++i) {
251+ my_stats.Notify (buf[i], row_idx, packed_ofs + i);
252+ }
253+ my_stats.NotifyCorr (Correlation (buf, kGroupSize ));
159254
160- for (size_t i = 0 ; i < 2 * NF; ++i) {
161- my_stats.Notify (buf[i], row_idx, packed_ofs + i);
255+ buf_filled = 0 ;
162256 }
163- my_stats.NotifyCorr (Correlation (buf, 2 * NF));
164257 }
165258 }
166259
167260 // Zero to two vectors remaining.
168261 for (; packed_ofs < cols; packed_ofs += NF) {
169262 const size_t remaining = HWY_MIN (NF, cols - packed_ofs);
170263 DecompressAndZeroPad (df, packed, packed_ofs, buf, remaining);
171- // Skip NotifyGroup for this partial group .
264+ // Skip QuantizeGroup because it requires full groups .
172265 for (size_t i = 0 ; i < remaining; ++i) {
173266 my_stats.Notify (buf[i], row_idx, packed_ofs + i);
174267 }
0 commit comments