1111#include " fbgemm/Utils.h"
1212
1313#define FBGEMM_EXPORTS
14+ #include < arm_fp16.h> // @manual
1415#include < arm_neon.h> // @manual
1516#if HAVE_SVE
17+ #include < arm_neon_sve_bridge.h> // @manual
1618#include < arm_sve.h> // @manual
1719#endif
1820
19- #include < arm_neon_sve_bridge.h> // @manual
2021#include < algorithm> // for std::min/std::max
2122#include < cassert> // for assert
2223#include < cfloat> // for FLT_MAX
@@ -32,41 +33,48 @@ namespace fbgemm {
3233using namespace std ;
3334// //////////////////////////////////////////////////////////////////////////////
3435// Utility functions
35-
// Vectorized min/max scan over `count` floats.
// Precondition: count > 0 (the caller filters non-positive lengths).
// Strategy: two pairs of NEON accumulators consume 8 lanes per iteration,
// the vector accumulators are folded with horizontal reductions, and any
// remaining (< 8) elements are swept with scalar std::min/std::max.
static inline void
FindMinMaxImpl_f32(const float* m, float* min, float* max, uint64_t count) {
  // Seeding every accumulator with the first element keeps the reductions
  // correct even when the vector loop never runs.
  const float seed = *m;

  float scalarMin = seed;
  float scalarMax = seed;

  float32x4_t vmin0 = vdupq_n_f32(seed);
  float32x4_t vmin1 = vdupq_n_f32(seed);
  float32x4_t vmax0 = vdupq_n_f32(seed);
  float32x4_t vmax1 = vdupq_n_f32(seed);

  constexpr uint64_t kLanesPerIter = 8;
  const uint64_t vectorLoops = count / kLanesPerIter;
  const uint64_t tail = count % kLanesPerIter;

  if (__builtin_expect(vectorLoops > 0, 1)) {
    for (uint64_t iter = 0; iter < vectorLoops; ++iter) {
      const float32x4_t lo = vld1q_f32(m);
      const float32x4_t hi = vld1q_f32(m + 4);
      m += kLanesPerIter;
      vmin0 = vminq_f32(vmin0, lo);
      vmin1 = vminq_f32(vmin1, hi);
      vmax0 = vmaxq_f32(vmax0, lo);
      vmax1 = vmaxq_f32(vmax1, hi);
    }

    // Fold the paired accumulators, then reduce across lanes.
    scalarMin = vminvq_f32(vminq_f32(vmin0, vmin1));
    scalarMax = vmaxvq_f32(vmaxq_f32(vmax0, vmax1));
  }

  // Scalar sweep of the remaining elements (m already points at the tail).
  for (uint64_t i = 0; i < tail; ++i) {
    const float v = m[i];
    scalarMin = std::min(scalarMin, v);
    scalarMax = std::max(scalarMax, v);
  }

  *min = scalarMin;
  *max = scalarMax;
}
7785
86+ void FindMinMax (const float * m, float * min, float * max, int64_t len) {
87+ if (__builtin_expect (len <= 0 , 0 )) {
88+ *min = 0 .0f ;
89+ *max = 0 .0f ;
90+ return ;
91+ }
92+
93+ FindMinMaxImpl_f32 (m, min, max, static_cast <uint64_t >(len));
94+ }
95+
7896#if HAVE_SVE
7997
// Vectorized min/max scan over `count` half-precision floats; results are
// widened to float on output.
// Precondition: count > 0 (callers filter empty inputs).
// Mirrors FindMinMaxImpl_f32: paired fp16 NEON accumulators consume
// 16 lanes per iteration, then horizontal reductions fold them and the
// tail is handled with scalar fp16 min/max intrinsics.
static inline void
FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
  // Seed all accumulators with the first element so the reductions are
  // correct even if the vector loop is skipped.
  const float16_t seed = *m;

  float16_t scalarMin = seed;
  float16_t scalarMax = seed;

  float16x8_t vmin0 = vdupq_n_f16(seed);
  float16x8_t vmin1 = vdupq_n_f16(seed);
  float16x8_t vmax0 = vdupq_n_f16(seed);
  float16x8_t vmax1 = vdupq_n_f16(seed);

  constexpr uint64_t kLanesPerIter = 16;
  const uint64_t vectorLoops = count / kLanesPerIter;
  const uint64_t tail = count % kLanesPerIter;

  if (__builtin_expect(vectorLoops > 0, 1)) {
    for (uint64_t iter = 0; iter < vectorLoops; ++iter) {
      const float16x8_t lo = vld1q_f16(m);
      const float16x8_t hi = vld1q_f16(m + 8);
      m += kLanesPerIter;
      vmin0 = vminq_f16(vmin0, lo);
      vmin1 = vminq_f16(vmin1, hi);
      vmax0 = vmaxq_f16(vmax0, lo);
      vmax1 = vmaxq_f16(vmax1, hi);
    }

    scalarMin = vminvq_f16(vminq_f16(vmin0, vmin1));
    scalarMax = vmaxvq_f16(vmaxq_f16(vmax0, vmax1));
  }

  // Scalar fp16 sweep of the remaining (< 16) elements.
  for (uint64_t i = 0; i < tail; ++i) {
    scalarMin = vminh_f16(scalarMin, m[i]);
    scalarMax = vmaxh_f16(scalarMax, m[i]);
  }

  *min = static_cast<float>(scalarMin);
  *max = static_cast<float>(scalarMax);
}
147+
// Row-wise 8-bit quantization with float scale/bias (NEON + SVE path).
//
// For each row: find min/max, then store
//   round((x - min) * 255 / (range + eps))
// as one uint8 per element, followed by two floats appended after the
// row's bytes: scale = range / 255 and bias = min.
//
// input:          input_rows x input_columns values (float or float16).
// output:         rows laid out contiguously; each output row holds
//                 input_columns quantized bytes plus 2 trailing floats.
//
// Fix vs. previous revision: the remainder loop declared an uninitialized
// float32x4_t and wrote only lane 0 before running full-width vector
// arithmetic, feeding indeterminate values into lanes 1-3. The scalar tail
// now splats the element into every lane, so no uninitialized data is read.
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    uint8_t* output) {
  // Guards against division by zero when a row is constant (range == 0).
  constexpr float kEpsilon = 1e-8f;

  if (input_rows == 0 || input_columns <= 0) {
    return;
  }

  const uint64_t column_count = static_cast<uint64_t>(input_columns);

  // Quantized bytes plus the trailing scale/bias pair.
  const uint64_t output_columns = column_count + 2 * sizeof(float);

  for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) {
    const InputType* input_row = input + row * column_count;
    uint8_t* output_row = output + row * output_columns;

    // Scale/bias live immediately after this row's quantized bytes.
    float* output_row_scale_bias =
        reinterpret_cast<float*>(output_row + column_count);

    float minimum_element;
    float maximum_element;
    if constexpr (std::is_same<InputType, float>()) {
      FindMinMaxImpl_f32(
          input_row, &minimum_element, &maximum_element, column_count);
    } else {
      FindMinMaxImpl_f16(
          reinterpret_cast<const float16_t*>(input_row),
          &minimum_element,
          &maximum_element,
          column_count);
    }
    const float range = maximum_element - minimum_element;

    const float inverse_scale = 255.0f / (range + kEpsilon);

    const float32x4_t inverse_scale_v = vdupq_n_f32(inverse_scale);
    const float32x4_t min_v = vdupq_n_f32(minimum_element);

    constexpr uint64_t kItemsPerIter = 8;
    uint64_t loopIters = column_count / kItemsPerIter;
    uint64_t loopRemainder = column_count % kItemsPerIter;

    output_row_scale_bias[0] = range / 255.0f;
    output_row_scale_bias[1] = minimum_element;

    while (__builtin_expect(loopIters > 0, 1)) {
      float32x4_t v0;
      float32x4_t v1;

      if constexpr (std::is_same<InputType, float>()) {
        v0 = vld1q_f32(input_row);
        v1 = vld1q_f32(input_row + 4);
      } else {
        // Load 8 halves and widen into two float32x4 vectors.
        float16x8_t h0 =
            vld1q_f16(reinterpret_cast<const float16_t*>(input_row));
        v0 = vcvt_f32_f16(vget_low_f16(h0));
        v1 = vcvt_high_f32_f16(h0);
      }

      input_row += kItemsPerIter;
      loopIters -= 1;

      v0 = vsubq_f32(v0, min_v);
      v1 = vsubq_f32(v1, min_v);

      v0 = vmulq_f32(v0, inverse_scale_v);
      v1 = vmulq_f32(v1, inverse_scale_v);

      // Round-to-nearest-even conversion to int32.
      int32x4_t i0 = vcvtnq_s32_f32(v0);
      int32x4_t i1 = vcvtnq_s32_f32(v1);

      // svst1b narrows each int32 lane to its low byte on store; the
      // quantized values are already in [0, 255], so no clamp is needed.
      svst1b_s32(
          svptrue_b8(),
          reinterpret_cast<int8_t*>(output_row),
          svset_neonq_s32(svundef_s32(), i0));
      svst1b_s32(
          svptrue_b8(),
          reinterpret_cast<int8_t*>(output_row + 4),
          svset_neonq_s32(svundef_s32(), i1));

      output_row += kItemsPerIter;
    }

#ifdef __clang__
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
#elif defined(__GNUC__)
#pragma GCC novector unroll 0
#endif
    while (loopRemainder > 0) {
      float elem;
      if constexpr (std::is_same<InputType, float>()) {
        elem = *input_row++;
      } else {
        elem =
            static_cast<float>(*reinterpret_cast<const float16_t*>(input_row));
        input_row += 1;
      }
      loopRemainder -= 1;
      // Splat so every lane is initialized (only lane 0 is consumed below).
      float32x4_t v0 = vdupq_n_f32(elem);
      v0 = vsubq_f32(v0, min_v);
      v0 = vmulq_f32(v0, inverse_scale_v);
      int32x4_t i0 = vcvtnq_s32_f32(v0);
      *output_row = static_cast<uint8_t>(vgetq_lane_s32(i0, 0));
      output_row += 1;
    }

  } // for each row
}
259+
80260template <typename OutputType>
81261void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon (
82262 const std::uint8_t * input,
@@ -179,7 +359,12 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
179359 const std::uint8_t * input, \
180360 size_t input_rows, \
181361 int input_columns, \
182- type* output);
362+ type* output); \
363+ template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon<type>( \
364+ const type* input, \
365+ size_t input_rows, \
366+ int input_columns, \
367+ uint8_t * output);
183368
184369// clang-format off
185370INSTANTIATE_QuantizationNeonFunctions8Bits (float )
0 commit comments