Skip to content

Commit 16aa87b

Browse files
Nicoshevmeta-codesync[bot]
authored andcommitted
Add NEON-based FloatOrHalfToFused8BitRowwiseQuantizedSBFloat (#5089)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2098 Pull Request resolved: #5089 Adding NEON translation of FloatOrHalfToFused8BitRowwiseQuantizedSBFloat, used by Ads Performance improves by an order of magnitude: Before: bit_rate, rows, cols, elems_per_usec, GB/Sec 8, 100, 16, 378.68, 1.51 8, 100, 64, 286.91, 1.15 8, 100, 128, 262.06, 1.05 8, 100, 256, 251.34, 1.01 8, 100, 512, 244.92, 0.98 8, 100, 1024, 237.35, 0.95 8, 100, 2048, 230.83, 0.92 8, 120, 16, 378.70, 1.51 8, 120, 64, 286.72, 1.15 8, 120, 128, 263.40, 1.05 8, 120, 256, 251.58, 1.01 8, 120, 512, 245.30, 0.98 8, 120, 1024, 238.17, 0.95 8, 120, 2048, 230.69, 0.92 8, 1000, 16, 392.85, 1.57 8, 1000, 64, 294.35, 1.18 8, 1000, 128, 264.35, 1.06 8, 1000, 256, 252.13, 1.01 8, 1000, 512, 245.50, 0.98 8, 1000, 1024, 241.61, 0.97 8, 1000, 2048, 231.39, 0.93 After: bit_rate, rows, cols, elems_per_usec, GB/Sec 8, 100, 16, 1855.59, 7.42 8, 100, 64, 2615.43, 10.46 8, 100, 128, 3134.34, 12.54 8, 100, 256, 2610.72, 10.44 8, 100, 512, 3065.20, 12.26 8, 100, 1024, 3535.29, 14.14 8, 100, 2048, 3757.66, 15.03 8, 120, 16, 1991.94, 7.97 8, 120, 64, 2971.25, 11.89 8, 120, 128, 3403.37, 13.61 8, 120, 256, 2750.87, 11.00 8, 120, 512, 3272.63, 13.09 8, 120, 1024, 3618.98, 14.48 8, 120, 2048, 3848.59, 15.39 8, 1000, 16, 2329.11, 9.32 8, 1000, 64, 3068.76, 12.28 8, 1000, 128, 3678.86, 14.72 8, 1000, 256, 4440.37, 17.76 8, 1000, 512, 4558.70, 18.23 8, 1000, 1024, 4620.94, 18.48 8, 1000, 2048, 3898.84, 15.60 Reviewed By: mcfi Differential Revision: D86236406 fbshipit-source-id: 12c20cbdbbc9b0674ccca8e1aa598b7de144dea9
1 parent 17a7653 commit 16aa87b

File tree

3 files changed

+226
-29
lines changed

3 files changed

+226
-29
lines changed

include/fbgemm/QuantUtilsNeon.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ namespace fbgemm {
2222
// Utility functions
2323
////////////////////////////////////////////////////////////////////////////////
2424

25+
template <typename InputType>
26+
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon(
27+
const InputType* input,
28+
size_t input_rows,
29+
int input_columns,
30+
uint8_t* output);
31+
2532
template <typename OutputType>
2633
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
2734
const std::uint8_t* input,

src/QuantUtils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,10 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
714714
int input_columns,
715715
std::uint8_t* output,
716716
const InputType* rowwise_min_max) {
717+
#if HAVE_SVE
718+
FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon<InputType>(
719+
input, input_rows, input_columns, output);
720+
#else
717721
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
718722
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
719723
FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<InputType>(
@@ -723,6 +727,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
723727
FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<InputType>(
724728
input, input_rows, input_columns, output);
725729
}
730+
#endif
726731
}
727732

728733
template <typename OutputType, bool is_uint16_t_of_type_bf16>

src/QuantUtilsNeon.cc

Lines changed: 214 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
#include "fbgemm/Utils.h"
1212

1313
#define FBGEMM_EXPORTS
14+
#include <arm_fp16.h> // @manual
1415
#include <arm_neon.h> // @manual
1516
#if HAVE_SVE
17+
#include <arm_neon_sve_bridge.h> // @manual
1618
#include <arm_sve.h> // @manual
1719
#endif
1820

19-
#include <arm_neon_sve_bridge.h> // @manual
2021
#include <algorithm> //for std::min/std::max
2122
#include <cassert> //for assert
2223
#include <cfloat> // for FLT_MAX
@@ -32,41 +33,48 @@ namespace fbgemm {
3233
using namespace std;
3334
////////////////////////////////////////////////////////////////////////////////
3435
// Utility functions
35-
36-
void FindMinMax(const float* m, float* min, float* max, int64_t len) {
37-
if (__builtin_expect(len <= 0, 0)) {
38-
*min = 0.0f;
39-
*max = 0.0f;
40-
return;
41-
}
42-
36+
static inline void
37+
FindMinMaxImpl_f32(const float* m, float* min, float* max, uint64_t count) {
4338
float first = *m;
4439

40+
float tmp_min_s = first;
41+
float tmp_max_s = first;
42+
4543
float32x4_t temp_min_0 = vdupq_n_f32(first);
4644
float32x4_t temp_min_1 = vdupq_n_f32(first);
4745
float32x4_t temp_max_0 = vdupq_n_f32(first);
4846
float32x4_t temp_max_1 = vdupq_n_f32(first);
49-
uint64_t i = 0;
50-
uint64_t count = static_cast<uint64_t>(len);
51-
uint64_t loopBound = count - (count % 8);
52-
53-
for (; i < loopBound; i += 8) {
54-
float32x4_t v0 = vld1q_f32(m + i);
55-
float32x4_t v1 = vld1q_f32(m + i + 4);
56-
temp_min_0 = vminq_f32(temp_min_0, v0);
57-
temp_min_1 = vminq_f32(temp_min_1, v1);
58-
temp_max_0 = vmaxq_f32(temp_max_0, v0);
59-
temp_max_1 = vmaxq_f32(temp_max_1, v1);
47+
constexpr uint64_t kItemsPerIter = 8;
48+
uint64_t loopIters = count / kItemsPerIter;
49+
uint64_t loopRemainder = count % kItemsPerIter;
50+
51+
if (__builtin_expect(loopIters > 0, 1)) {
52+
do {
53+
float32x4_t v0 = vld1q_f32(m);
54+
float32x4_t v1 = vld1q_f32(m + 4);
55+
m += kItemsPerIter;
56+
loopIters -= 1;
57+
temp_min_0 = vminq_f32(temp_min_0, v0);
58+
temp_min_1 = vminq_f32(temp_min_1, v1);
59+
temp_max_0 = vmaxq_f32(temp_max_0, v0);
60+
temp_max_1 = vmaxq_f32(temp_max_1, v1);
61+
} while (loopIters > 0);
62+
63+
temp_min_0 = vminq_f32(temp_min_0, temp_min_1);
64+
temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1);
65+
66+
tmp_min_s = vminvq_f32(temp_min_0);
67+
tmp_max_s = vmaxvq_f32(temp_max_0);
6068
}
6169

62-
temp_min_0 = vminq_f32(temp_min_0, temp_min_1);
63-
temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1);
64-
65-
float tmp_min_s = vminvq_f32(temp_min_0);
66-
float tmp_max_s = vmaxvq_f32(temp_max_0);
67-
68-
for (; i < count; i++) {
69-
float tmp = *m;
70+
#ifdef __clang__
71+
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
72+
#elif defined(__GNUC__)
73+
#pragma GCC novector unroll 0
74+
#endif
75+
while (loopRemainder > 0) {
76+
float tmp = *m++;
77+
loopRemainder -= 1;
7078
tmp_min_s = std::min(tmp_min_s, tmp);
7179
tmp_max_s = std::max(tmp_max_s, tmp);
7280
}
@@ -75,8 +83,180 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) {
7583
*max = tmp_max_s;
7684
}
7785

86+
void FindMinMax(const float* m, float* min, float* max, int64_t len) {
87+
if (__builtin_expect(len <= 0, 0)) {
88+
*min = 0.0f;
89+
*max = 0.0f;
90+
return;
91+
}
92+
93+
FindMinMaxImpl_f32(m, min, max, static_cast<uint64_t>(len));
94+
}
95+
7896
#if HAVE_SVE
7997

98+
static inline void
99+
FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
100+
float16_t first = *m;
101+
102+
float16_t tmp_min_s = first;
103+
float16_t tmp_max_s = first;
104+
105+
float16x8_t temp_min_0 = vdupq_n_f16(first);
106+
float16x8_t temp_min_1 = vdupq_n_f16(first);
107+
float16x8_t temp_max_0 = vdupq_n_f16(first);
108+
float16x8_t temp_max_1 = vdupq_n_f16(first);
109+
constexpr uint64_t kItemsPerIter = 16;
110+
uint64_t loopIters = count / kItemsPerIter;
111+
uint64_t loopRemainder = count % kItemsPerIter;
112+
113+
if (__builtin_expect(loopIters > 0, 1)) {
114+
do {
115+
float16x8_t v0 = vld1q_f16(m);
116+
float16x8_t v1 = vld1q_f16(m + 8);
117+
m += kItemsPerIter;
118+
loopIters -= 1;
119+
temp_min_0 = vminq_f16(temp_min_0, v0);
120+
temp_min_1 = vminq_f16(temp_min_1, v1);
121+
temp_max_0 = vmaxq_f16(temp_max_0, v0);
122+
temp_max_1 = vmaxq_f16(temp_max_1, v1);
123+
} while (loopIters > 0);
124+
125+
temp_min_0 = vminq_f16(temp_min_0, temp_min_1);
126+
temp_max_0 = vmaxq_f16(temp_max_0, temp_max_1);
127+
128+
tmp_min_s = vminvq_f16(temp_min_0);
129+
tmp_max_s = vmaxvq_f16(temp_max_0);
130+
}
131+
132+
#ifdef __clang__
133+
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
134+
#elif defined(__GNUC__)
135+
#pragma GCC novector unroll 0
136+
#endif
137+
while (loopRemainder > 0) {
138+
float16_t tmp = *m++;
139+
loopRemainder -= 1;
140+
tmp_min_s = vminh_f16(tmp_min_s, tmp);
141+
tmp_max_s = vmaxh_f16(tmp_max_s, tmp);
142+
}
143+
144+
*min = static_cast<float>(tmp_min_s);
145+
*max = static_cast<float>(tmp_max_s);
146+
}
147+
148+
template <typename InputType>
149+
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon(
150+
const InputType* input,
151+
size_t input_rows,
152+
int input_columns,
153+
uint8_t* output) {
154+
constexpr float kEpsilon = 1e-8f;
155+
156+
if (input_rows == 0 || input_columns <= 0) {
157+
return;
158+
}
159+
160+
uint64_t column_count = static_cast<uint64_t>(input_columns);
161+
162+
const uint64_t output_columns = column_count + 2 * sizeof(float);
163+
164+
for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) {
165+
const InputType* input_row = input + row * column_count;
166+
uint8_t* output_row = output + row * output_columns;
167+
168+
float* output_row_scale_bias =
169+
reinterpret_cast<float*>(output_row + column_count);
170+
171+
float minimum_element;
172+
float maximum_element;
173+
if constexpr (std::is_same<InputType, float>()) {
174+
FindMinMaxImpl_f32(
175+
input_row, &minimum_element, &maximum_element, column_count);
176+
} else {
177+
FindMinMaxImpl_f16(
178+
reinterpret_cast<const float16_t*>(input_row),
179+
&minimum_element,
180+
&maximum_element,
181+
column_count);
182+
}
183+
float range = maximum_element - minimum_element;
184+
185+
const auto inverse_scale = 255.0f / (range + kEpsilon);
186+
187+
float32x4_t inverse_scale_v = vdupq_n_f32(inverse_scale);
188+
float32x4_t min_v = vdupq_n_f32(minimum_element);
189+
190+
constexpr uint64_t kItemsPerIter = 8;
191+
uint64_t loopIters = column_count / kItemsPerIter;
192+
uint64_t loopRemainder = column_count % kItemsPerIter;
193+
194+
output_row_scale_bias[0] = range / 255.0f;
195+
output_row_scale_bias[1] = minimum_element;
196+
197+
while (__builtin_expect(loopIters > 0, 1)) {
198+
float32x4_t v0;
199+
float32x4_t v1;
200+
201+
if constexpr (std::is_same<InputType, float>()) {
202+
v0 = vld1q_f32(input_row);
203+
v1 = vld1q_f32(input_row + 4);
204+
} else {
205+
float16x8_t h0 =
206+
vld1q_f16(reinterpret_cast<const float16_t*>(input_row));
207+
v0 = vcvt_f32_f16(vget_low_f16(h0));
208+
v1 = vcvt_high_f32_f16(h0);
209+
}
210+
211+
input_row += kItemsPerIter;
212+
loopIters -= 1;
213+
214+
v0 = vsubq_f32(v0, min_v);
215+
v1 = vsubq_f32(v1, min_v);
216+
217+
v0 = vmulq_f32(v0, inverse_scale_v);
218+
v1 = vmulq_f32(v1, inverse_scale_v);
219+
220+
int32x4_t i0 = vcvtnq_s32_f32(v0);
221+
int32x4_t i1 = vcvtnq_s32_f32(v1);
222+
223+
svst1b_s32(
224+
svptrue_b8(),
225+
reinterpret_cast<int8_t*>(output_row),
226+
svset_neonq_s32(svundef_s32(), i0));
227+
svst1b_s32(
228+
svptrue_b8(),
229+
reinterpret_cast<int8_t*>(output_row + 4),
230+
svset_neonq_s32(svundef_s32(), i1));
231+
232+
output_row += kItemsPerIter;
233+
}
234+
235+
#ifdef __clang__
236+
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
237+
#elif defined(__GNUC__)
238+
#pragma GCC novector unroll 0
239+
#endif
240+
while (loopRemainder > 0) {
241+
float32x4_t v0;
242+
if constexpr (std::is_same<InputType, float>()) {
243+
v0[0] = *input_row++;
244+
} else {
245+
v0[0] =
246+
static_cast<float>(*reinterpret_cast<const float16_t*>(input_row));
247+
input_row += 1;
248+
}
249+
loopRemainder -= 1;
250+
v0 = vsubq_f32(v0, min_v);
251+
v0 = vmulq_f32(v0, inverse_scale_v);
252+
int32x4_t i0 = vcvtnq_s32_f32(v0);
253+
*output_row = i0[0];
254+
output_row += 1;
255+
}
256+
257+
} // for each row
258+
}
259+
80260
template <typename OutputType>
81261
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
82262
const std::uint8_t* input,
@@ -179,7 +359,12 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
179359
const std::uint8_t* input, \
180360
size_t input_rows, \
181361
int input_columns, \
182-
type* output);
362+
type* output); \
363+
template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon<type>( \
364+
const type* input, \
365+
size_t input_rows, \
366+
int input_columns, \
367+
uint8_t* output);
183368

184369
// clang-format off
185370
INSTANTIATE_QuantizationNeonFunctions8Bits(float)

0 commit comments

Comments
 (0)