 #include <c10/util/Unroll.h>
 #include <torch/all.h>

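+// Dispatch helpers in the style of AT_DISPATCH_*: each macro switches on a
+// runtime ScalarType and binds a type alias (data_t / output_t) that the
+// passed-in lambda can use, so one kernel template serves every supported
+// dtype combination.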
+#define QTYPE_DISPATCH(TYPE, ...)                                      \
+  [&]() {                                                              \
+    switch (TYPE) {                                                    \
+    case c10::ScalarType::Float8_e4m3fn: {                             \
+      using data_t = at::Float8_e4m3fn;                                \
+      return __VA_ARGS__();                                            \
+    }                                                                  \
+    case c10::ScalarType::Char: {                                      \
+      using data_t = int8_t;                                           \
+      return __VA_ARGS__();                                            \
+    }                                                                  \
+    default:                                                           \
+      TORCH_CHECK(false, "scaled_embedding_bag: unsupported qtype");   \
+    }                                                                  \
+  }()
+
+#define OUTTYPE_DISPATCH(TYPE, ...)                                          \
+  [&]() {                                                                    \
+    switch (TYPE) {                                                          \
+    case c10::ScalarType::Float: {                                           \
+      using output_t = float;                                                \
+      return __VA_ARGS__();                                                  \
+    }                                                                        \
+    case c10::ScalarType::Char: {                                            \
+      using output_t = int8_t;                                               \
+      return __VA_ARGS__();                                                  \
+    }                                                                        \
+    default:                                                                 \
+      TORCH_CHECK(false, "scaled_embedding_bag: unsupported output type");   \
+    }                                                                        \
+  }()
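+
+// Illustrative nesting (mirrors the call site further down in this file);
+// `run_kernel` is a placeholder name, not a function defined here:
+//   OUTTYPE_DISPATCH(out_dtype, [&] {
+//     QTYPE_DISPATCH(qtype, [&] { run_kernel<data_t, output_t>(...); });
+//   });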
+
 namespace torchao {

 namespace {
@@ -53,14 +85,71 @@ static inline CHUNK load_chunk(const int8_t *x) {
   x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
   return {x0, x1, x2, x3, x4, x5, x6, x7};
 }
+
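+// store_chunk writes one 128-float chunk (8 x 16 lanes) back to memory.
+// The float overload stores the registers directly; the int8_t overload below
+// rounds to nearest, saturates to [-128, 127] via _mm512_cvtsepi32_epi8, and
+// packs the chunk into two 64-byte stores.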
+static inline void store_chunk(float *output, CHUNK chunk) {
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = chunk;
+  _mm512_store_ps(output, x0);
+  _mm512_store_ps(output + 16, x1);
+  _mm512_store_ps(output + 32, x2);
+  _mm512_store_ps(output + 48, x3);
+  _mm512_store_ps(output + 64, x4);
+  _mm512_store_ps(output + 80, x5);
+  _mm512_store_ps(output + 96, x6);
+  _mm512_store_ps(output + 112, x7);
+}
+
+static inline void store_chunk(int8_t *output, CHUNK chunk) {
+  // start from zero so every byte is defined before the 128-bit inserts below
+  __m512i x00 = _mm512_setzero_si512();
+  __m512i x64 = _mm512_setzero_si512();
+  __m512i y0, y1, y2, y3, y4, y5, y6, y7;
+  __m512 f0, f1, f2, f3, f4, f5, f6, f7;
+  std::tie(f0, f1, f2, f3, f4, f5, f6, f7) = chunk;
+  y0 = _mm512_cvt_roundps_epi32(
+      f0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y1 = _mm512_cvt_roundps_epi32(
+      f1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y2 = _mm512_cvt_roundps_epi32(
+      f2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y3 = _mm512_cvt_roundps_epi32(
+      f3, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y4 = _mm512_cvt_roundps_epi32(
+      f4, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y5 = _mm512_cvt_roundps_epi32(
+      f5, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y6 = _mm512_cvt_roundps_epi32(
+      f6, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y7 = _mm512_cvt_roundps_epi32(
+      f7, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y0), 0);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y1), 1);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y2), 2);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y3), 3);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y4), 0);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y5), 1);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y6), 2);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y7), 3);
+  _mm512_store_si512(output, x00);
+  _mm512_store_si512(output + 64, x64);
+}
 #endif

-template <typename index_t, typename data_t>
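+// Scalar counterparts of store_chunk, used on the non-vectorized path:
+// write the value as float, or round and clamp it to the int8 range to match
+// the saturating vector conversion above.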
+static inline void store_elem(float &out, float input) {
+  out = input;
+}
+
+static inline void store_elem(int8_t &out, float input) {
+  float rounded = std::round(input);
+  float clamped = std::max(-128.0f, std::min(127.0f, rounded));
+  int32_t int32_value = static_cast<int32_t>(clamped);
+  out = static_cast<int8_t>(int32_value);
+}
+
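+// Per-thread kernel: for each bag in [bs_begin, bs_end) it sums the selected
+// embedding rows (SUM mode), multiplies the sum by `scale`, and stores it as
+// output_t -- directly for float, round-and-saturate for int8.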
+template <typename index_t, typename data_t, typename output_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
     const index_t *offsets, const data_t *weight, const double scale,
-    float *result, const int64_t num_batch) {
+    output_t *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
     constexpr int64_t block_dim = 128;
@@ -76,7 +165,7 @@ inline void _scaled_embedding_bag_krnl(
       for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
         // load first indices
         int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
-        float *block_result = result + block_dim * block_id;
+        output_t *block_result = result + block_dim * block_id;
         std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
         for (int64_t j = start_idx + 1; j < end_idx; ++j) {
           // add following idx
@@ -100,14 +189,7 @@ inline void _scaled_embedding_bag_krnl(
         x6 = _mm512_mul_ps(x6, scale_v);
         x7 = _mm512_mul_ps(x7, scale_v);
         // store
-        _mm512_store_ps(block_result, x0);
-        _mm512_store_ps(block_result + 16, x1);
-        _mm512_store_ps(block_result + 32, x2);
-        _mm512_store_ps(block_result + 48, x3);
-        _mm512_store_ps(block_result + 64, x4);
-        _mm512_store_ps(block_result + 80, x5);
-        _mm512_store_ps(block_result + 96, x6);
-        _mm512_store_ps(block_result + 112, x7);
+        store_chunk(block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
       }
       result += num_emb * emb_dim;
     }
@@ -127,14 +209,14 @@ inline void _scaled_embedding_bag_krnl(
         value += float(weight[idx + d]);
       }
       value = value * scale;
-      result[d] = value;
+      store_elem(result[d], value);
     }
     result += num_emb * emb_dim;
   }
 }

-template <typename index_t, typename data_t>
-void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
+template <typename index_t, typename data_t, typename output_t>
+void _scaled_embedding_bag(output_t *o_ptr, data_t *w_ptr, index_t *indices_ptr,
                            index_t *offsets_ptr, int64_t num_batch,
                            int64_t emb_dim, index_t last_offset, double w_scale,
                            double o_scale) {
@@ -147,7 +229,7 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
     for (int64_t n = 0; n < num_emb; ++n) {
       const int64_t bs_begin = b * b_block;
       const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
-      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
+      output_t *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
       // handle the case where offsets does not include the last batch
       _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
                                  last_offset, indices_ptr, offsets_ptr, w_ptr,
@@ -156,12 +238,24 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
   }
 }

-at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
-                                      const at::Tensor &indices,
-                                      const at::Tensor &offsets,
-                                      const at::Tensor &w_scales,
-                                      double o_scale, const int64_t mode,
-                                      bool include_last_offset) {
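+// Helper so the macro-generated lambdas below only have to name the three
+// template parameters; it extracts the typed pointers and forwards them to
+// _scaled_embedding_bag.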
+template <typename index_t, typename data_t, typename output_t>
+void _scaled_embedding_bag_dispatch_dtype(
+    const at::Tensor &qweight, const at::Tensor &indices,
+    const at::Tensor &offsets, const at::Tensor &output, int64_t batch_size,
+    int64_t emb_dim, index_t last_offset, double w_scale, double o_scale) {
+  data_t *qweight_ptr = qweight.data_ptr<data_t>();
+  index_t *indices_ptr = indices.data_ptr<index_t>();
+  index_t *offsets_ptr = offsets.data_ptr<index_t>();
+  output_t *output_ptr = output.data_ptr<output_t>();
+  _scaled_embedding_bag<index_t, data_t, output_t>(
+      output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
+      last_offset, w_scale, o_scale);
+}
+
+at::Tensor _scaled_embedding_bag_impl(
+    const at::Tensor &qweight, const at::Tensor &indices,
+    const at::Tensor &offsets, const at::Tensor &w_scales, double o_scale,
+    const int64_t mode, bool include_last_offset, at::ScalarType output_dtype) {
   // Only include_last_offset == true and
   // mode == at::native::EmbeddingBagMode::SUM are supported.
   // TODO: Support more cases.
@@ -193,32 +287,17 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t last_offset = indices.numel();

   at::Tensor output =
-      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  if (qtype == c10::ScalarType::Float8_e4m3fn) {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          at::Float8_e4m3fn *qweight_ptr =
-              qweight.data_ptr<at::Float8_e4m3fn>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  } else {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, int8_t>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  }
-
+      at::empty({batch_size, emb_dim}, qweight.options().dtype(output_dtype));
+  OUTTYPE_DISPATCH(output_dtype, [&] {
+    QTYPE_DISPATCH(qtype, [&] {
+      AT_DISPATCH_INDEX_TYPES(
+          indices.scalar_type(), "_scaled_embedding_bag", [&] {
+            _scaled_embedding_bag_dispatch_dtype<index_t, data_t, output_t>(
+                qweight, indices, offsets, output, batch_size, emb_dim,
+                last_offset, w_scale, o_scale);
+          });
+    });
+  });
   return output;
 }
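+// Example call (illustrative only; shapes and values are assumptions):
+//   at::Tensor out = _scaled_embedding_bag_impl(
+//       qweight,            // [num_embeddings, emb_dim] int8 / fp8e4m3 rows
+//       indices, offsets,   // 1-D index tensors, include_last_offset layout
+//       w_scales,
+//       /*o_scale=*/1.0,
+//       /*mode=*/0,         // EmbeddingBagMode::SUM
+//       /*include_last_offset=*/true,
+//       /*output_dtype=*/at::kFloat);  // or at::kChar for int8 output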