
Commit 9008f84

Support int8 output for scaled_embedding_bag (#3231)
* re-enable scaled_embedding_bag
* only support fp32 out_dtype
* support int8 output
* support int8 fallback path
* fix lint
* fix lint and clang
* refine scale (Co-authored-by: Xia Weiwen <[email protected]>)
* refine code (Co-authored-by: Xia Weiwen <[email protected]>)
* refine code

---------

Co-authored-by: Xia Weiwen <[email protected]>
1 parent 7d91f11 commit 9008f84

3 files changed: +154 additions, -59 deletions


test/test_ops.py

Lines changed: 26 additions & 8 deletions
@@ -864,9 +864,13 @@ def test_swizzle_mm():
 
 
 def _test_scaled_embedding_bag_cpu_helper(
-    multi_hot, batch_size, vector_size, index_type, qtype
+    multi_hot,
+    batch_size,
+    vector_size,
+    index_type,
+    qtype,
+    out_dtype=torch.float,
 ):
-    dtype = torch.float32
     include_last_offset = True
     mode = "sum"
 
@@ -883,7 +887,7 @@ def _test_scaled_embedding_bag_cpu_helper(
         1000,
         vector_size,
         mode=mode,
-        dtype=dtype,
+        dtype=torch.float,
         include_last_offset=include_last_offset,
     )
     if qtype == torch.int8:
@@ -894,17 +898,25 @@ def _test_scaled_embedding_bag_cpu_helper(
         qweight = m.weight.data.to(qtype)
        m.weight.data = qweight.to(m.weight.dtype)
 
+    out_scale = 1.0
+    if out_dtype == torch.int8:
+        out_scale = 2.0
+
     with torch.no_grad():
         refe_out = m.forward(indices, offsets) * weight_scale
+        if out_dtype == torch.int8:
+            refe_out = torch.round(refe_out / out_scale).to(torch.int32)
+            refe_out = torch.clamp(refe_out, -128, 127).to(out_dtype)
         test_out = torch.ops.torchao._scaled_embedding_bag(
             qweight,
             indices,
             offsets,
             weight_scale,
-            1.0,
+            out_scale,
             mode_enum,
             include_last_offset,
-        ).to(dtype)
+            out_dtype,
+        )
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
 
 
@@ -918,9 +930,15 @@ def _test_scaled_embedding_bag_cpu_helper(
     ids=str,
 )
 def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
-    _test_scaled_embedding_bag_cpu_helper(
-        multi_hot, batch_size, vector_size, index_type, torch.int8
-    )
+    for out_dtype in [torch.float, torch.int8]:
+        _test_scaled_embedding_bag_cpu_helper(
+            multi_hot,
+            batch_size,
+            vector_size,
+            index_type,
+            torch.int8,
+            out_dtype,
+        )
 
 
 @pytest.mark.skipif(
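
The reference path added above requantizes the fp32 bag sum before comparing with the kernel output. A minimal standalone sketch of that round-and-saturate step (the helper name requantize_to_int8 is illustrative, not part of the test):

import torch

def requantize_to_int8(fp32_out: torch.Tensor, out_scale: float) -> torch.Tensor:
    # Mirror the test's reference: divide by the output scale, round to the
    # nearest integer, then saturate to the int8 range [-128, 127].
    q = torch.round(fp32_out / out_scale).to(torch.int32)
    return torch.clamp(q, -128, 127).to(torch.int8)

# With out_scale = 2.0, the value the test uses for int8 output:
print(requantize_to_int8(torch.tensor([255.0, -300.0, 3.4]), 2.0))
# tensor([ 127, -128,    2], dtype=torch.int8)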

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 126 additions & 47 deletions
@@ -6,6 +6,38 @@
 #include <c10/util/Unroll.h>
 #include <torch/all.h>
 
+#define QTYPE_DISPATCH(TYPE, ...) \
+  [&]() { \
+    switch (TYPE) { \
+    case c10::ScalarType::Float8_e4m3fn: { \
+      using data_t = at::Float8_e4m3fn; \
+      return __VA_ARGS__(); \
+    } \
+    case c10::ScalarType::Char: { \
+      using data_t = int8_t; \
+      return __VA_ARGS__(); \
+    } \
+    default: \
+      TORCH_CHECK(false, "scaled_embeding_bag: unsupport qtype"); \
+    } \
+  }()
+
+#define OUTTYPE_DISPATCH(TYPE, ...) \
+  [&]() { \
+    switch (TYPE) { \
+    case c10::ScalarType::Float: { \
+      using output_t = float; \
+      return __VA_ARGS__(); \
+    } \
+    case c10::ScalarType::Char: { \
+      using output_t = int8_t; \
+      return __VA_ARGS__(); \
+    } \
+    default: \
+      TORCH_CHECK(false, "scaled_embeding_bag: unsupport output type"); \
+    } \
+  }()
+
 namespace torchao {
 
 namespace {
@@ -53,14 +85,71 @@ static inline CHUNK load_chunk(const int8_t *x) {
   x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
   return {x0, x1, x2, x3, x4, x5, x6, x7};
 }
+
+static inline void store_chunk(float *output, CHUNK chunk) {
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = chunk;
+  _mm512_store_ps(output, x0);
+  _mm512_store_ps(output + 16, x1);
+  _mm512_store_ps(output + 32, x2);
+  _mm512_store_ps(output + 48, x3);
+  _mm512_store_ps(output + 64, x4);
+  _mm512_store_ps(output + 80, x5);
+  _mm512_store_ps(output + 96, x6);
+  _mm512_store_ps(output + 112, x7);
+}
+
+static inline void store_chunk(int8_t *output, CHUNK chunk) {
+  __m512i x00, x64;
+  __m512i y0, y1, y2, y3, y4, y5, y6, y7;
+  __m512 f0, f1, f2, f3, f4, f5, f6, f7;
+  std::tie(f0, f1, f2, f3, f4, f5, f6, f7) = chunk;
+  y0 = _mm512_cvt_roundps_epi32(
+      f0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y1 = _mm512_cvt_roundps_epi32(
+      f1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y2 = _mm512_cvt_roundps_epi32(
+      f2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y3 = _mm512_cvt_roundps_epi32(
+      f3, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y4 = _mm512_cvt_roundps_epi32(
+      f4, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y5 = _mm512_cvt_roundps_epi32(
+      f5, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y6 = _mm512_cvt_roundps_epi32(
+      f6, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y7 = _mm512_cvt_roundps_epi32(
+      f7, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y0), 0);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y1), 1);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y2), 2);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y3), 3);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y4), 0);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y5), 1);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y6), 2);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y7), 3);
+  _mm512_store_si512(output, x00);
+  _mm512_store_si512(output + 64, x64);
+}
 #endif
 
-template <typename index_t, typename data_t>
+static inline void store_elem(float &out, float input) {
+  out = input;
+}
+
+static inline void store_elem(int8_t &out, float input) {
+  float rounded = std::round(input);
+  float clamped = std::max(-128.0f, std::min(127.0f, rounded));
+  int32_t int32_value = static_cast<int32_t>(clamped);
+  out = static_cast<int8_t>(int32_value);
+}
+
+template <typename index_t, typename data_t, typename output_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
     const index_t *offsets, const data_t *weight, const double scale,
-    float *result, const int64_t num_batch) {
+    output_t *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
     constexpr int64_t block_dim = 128;
@@ -76,7 +165,7 @@ inline void _scaled_embedding_bag_krnl(
     for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
       // load first indices
       int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
-      float *block_result = result + block_dim * block_id;
+      output_t *block_result = result + block_dim * block_id;
       std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
       for (int64_t j = start_idx + 1; j < end_idx; ++j) {
         // add following idx
@@ -100,14 +189,7 @@ inline void _scaled_embedding_bag_krnl(
      x6 = _mm512_mul_ps(x6, scale_v);
      x7 = _mm512_mul_ps(x7, scale_v);
      // store
-      _mm512_store_ps(block_result, x0);
-      _mm512_store_ps(block_result + 16, x1);
-      _mm512_store_ps(block_result + 32, x2);
-      _mm512_store_ps(block_result + 48, x3);
-      _mm512_store_ps(block_result + 64, x4);
-      _mm512_store_ps(block_result + 80, x5);
-      _mm512_store_ps(block_result + 96, x6);
-      _mm512_store_ps(block_result + 112, x7);
+      store_chunk(block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
     }
    result += num_emb * emb_dim;
  }
@@ -127,14 +209,14 @@ inline void _scaled_embedding_bag_krnl(
        value += float(weight[idx + d]);
      }
      value = value * scale;
-      result[d] = value;
+      store_elem(result[d], value);
    }
    result += num_emb * emb_dim;
  }
 }
 
-template <typename index_t, typename data_t>
-void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
+template <typename index_t, typename data_t, typename output_t>
+void _scaled_embedding_bag(output_t *o_ptr, data_t *w_ptr, index_t *indices_ptr,
                            index_t *offsets_ptr, int64_t num_batch,
                            int64_t emb_dim, index_t last_offset, double w_scale,
                            double o_scale) {
@@ -147,7 +229,7 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
    for (int64_t n = 0; n < num_emb; ++n) {
      const int64_t bs_begin = b * b_block;
      const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
-      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
+      output_t *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // avoid offsets not include last batch
      _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
                                 last_offset, indices_ptr, offsets_ptr, w_ptr,
@@ -156,12 +238,24 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
  }
 }
 
-at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
-                                      const at::Tensor &indices,
-                                      const at::Tensor &offsets,
-                                      const at::Tensor &w_scales,
-                                      double o_scale, const int64_t mode,
-                                      bool include_last_offset) {
+template <typename index_t, typename data_t, typename output_t>
+void _scaled_embedding_bag_dispatch_dtype(
+    const at::Tensor &qweight, const at::Tensor &indices,
+    const at::Tensor &offsets, const at::Tensor &output, int64_t batch_size,
+    int64_t emb_dim, index_t last_offset, double w_scale, double o_scale) {
+  data_t *qweight_ptr = qweight.data_ptr<data_t>();
+  index_t *indices_ptr = indices.data_ptr<index_t>();
+  index_t *offsets_ptr = offsets.data_ptr<index_t>();
+  output_t *output_ptr = output.data_ptr<output_t>();
+  _scaled_embedding_bag<index_t, data_t, output_t>(
+      output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
+      last_offset, w_scale, o_scale);
+}
+
+at::Tensor _scaled_embedding_bag_impl(
+    const at::Tensor &qweight, const at::Tensor &indices,
+    const at::Tensor &offsets, const at::Tensor &w_scales, double o_scale,
+    const int64_t mode, bool include_last_offset, at::ScalarType output_dtype) {
   // Only support include_last_offset == True and mode ==
   // at::native::EmbeddingBagMode::SUM
   // TODO: Support more case
@@ -193,32 +287,17 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t last_offset = indices.numel();
 
   at::Tensor output =
-      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  if (qtype == c10::ScalarType::Float8_e4m3fn) {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          at::Float8_e4m3fn *qweight_ptr =
-              qweight.data_ptr<at::Float8_e4m3fn>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  } else {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, int8_t>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  }
-
+      at::empty({batch_size, emb_dim}, qweight.options().dtype(output_dtype));
+  OUTTYPE_DISPATCH(output_dtype, [&] {
+    QTYPE_DISPATCH(qtype, [&] {
+      AT_DISPATCH_INDEX_TYPES(
+          indices.scalar_type(), "_scaled_embedding_bag", [&] {
+            _scaled_embedding_bag_dispatch_dtype<index_t, data_t, output_t>(
+                qweight, indices, offsets, output, batch_size, emb_dim,
+                last_offset, w_scale, o_scale);
+          });
+    });
+  });
   return output;
 }
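
For orientation, a rough Python reference of what the kernel computes in sum mode with include_last_offset=True, matching the test's expectations: dequantize rows, sum each bag, apply the weight scale, then either keep fp32 or round/saturate to int8 via the output scale. This is a sketch only; scaled_embedding_bag_ref is a hypothetical helper, and torch.round here stands in for the exact AVX512/std::round rounding behavior.

import torch

def scaled_embedding_bag_ref(qweight, indices, offsets, w_scale, o_scale, out_dtype):
    # Dequantize rows to fp32 and sum each bag [offsets[i], offsets[i+1]).
    emb = qweight.to(torch.float32)
    bags = [emb[indices[b:e]].sum(dim=0)
            for b, e in zip(offsets[:-1].tolist(), offsets[1:].tolist())]
    out = torch.stack(bags) * w_scale
    if out_dtype == torch.int8:
        # The int8 output path: scale, round, clamp to [-128, 127], narrow.
        out = torch.clamp(torch.round(out / o_scale), -128, 127).to(torch.int8)
    return out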

torchao/ops.py

Lines changed: 2 additions & 4 deletions
@@ -69,7 +69,7 @@
     "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
 )
 lib.define(
-    "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
+    "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset, ScalarType output_dtype) -> Tensor"
 )
 lib.define(
     "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)"
@@ -1118,13 +1118,11 @@ def _(
     o_scale: float,
     mode: int,
     include_last_offset: bool,
+    out_dtype: torch.dtype,
 ) -> Tensor:
     # Only support include_last_offset == True
     assert include_last_offset == True
     batch_size = offsets.shape[0] - 1
-    # Only support out_dtype == torch.float32
-    # Next setp: support more out_dtype
-    out_dtype = torch.float32
     return qweight.new_empty(batch_size, qweight.shape[1], dtype=out_dtype)
 
 
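
With the updated schema, callers pass the output dtype as the final argument and the meta kernel derives the output shape as (offsets.shape[0] - 1, qweight.shape[1]). A hypothetical invocation sketch; the shapes, scale values, single-element weight_scale tensor, and the mode value 0 for sum are illustrative assumptions, and it presumes a torchao build with the CPU kernel registered:

import torch
import torchao  # registers torch.ops.torchao._scaled_embedding_bag

qweight = torch.randint(-128, 128, (1000, 128), dtype=torch.int8)  # int8 table
indices = torch.randint(0, 1000, (64,), dtype=torch.int64)
offsets = torch.arange(0, 65, 4, dtype=torch.int64)  # include_last_offset=True
weight_scale = torch.tensor(0.05)                     # single-element scale (assumed)

int8_out = torch.ops.torchao._scaled_embedding_bag(
    qweight, indices, offsets, weight_scale,
    2.0,          # o_scale
    0,            # mode: sum (assumed enum value)
    True,         # include_last_offset
    torch.int8,   # output_dtype (new argument)
)
assert int8_out.dtype == torch.int8 and int8_out.shape == (16, 128)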
