Skip to content

Commit 0cad963

Browse files
committed
[FORK][CPU][FEATURE] InnerProduct primitive: u2 weights decompression
1 parent 5b93a37 commit 0cad963

24 files changed

+494
-81
lines changed

include/oneapi/dnnl/dnnl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,8 @@ struct memory : public handle<dnnl_memory_t> {
907907
s4 = dnnl_s4,
908908
/// 4-bit unsigned integer.
909909
u4 = dnnl_u4,
910+
/// 2-bit unsigned integer.
911+
u2 = dnnl_u2,
910912
/// 1-bit integer
911913
bin = dnnl_bin,
912914
/// 4-bit normalized float.

include/oneapi/dnnl/dnnl_common_types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ typedef enum {
112112
dnnl_nf4 = 16,
113113
/// 1-bit integer.
114114
dnnl_bin = 17,
115+
/// 2-bit unsigned integer.
116+
dnnl_u2 = 18,
115117

116118
/// Parameter to allow internal only data_types without undefined behavior.
117119
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.

src/common/c_types_map.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ const data_type_t s8 = dnnl_s8;
176176
const data_type_t u8 = dnnl_u8;
177177
const data_type_t s4 = dnnl_s4;
178178
const data_type_t u4 = dnnl_u4;
179+
const data_type_t u2 = dnnl_u2;
179180
const data_type_t boolean = dnnl_boolean;
180181
const data_type_t data_type_max = dnnl_data_type_max;
181182

src/common/dnnl_traits.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ template <> struct prec_traits_t<data_type::nf4> {
107107
using type = uint8_t;
108108
};
109109

110+
template <> struct prec_traits_t<data_type::u2> {
111+
using type = uint8_t;
112+
};
113+
110114
template <>
111115
struct data_traits_t<float4_e3m0_t> {
112116
static constexpr data_type_t data_type = data_type::f4_e3m0;

src/common/inner_product.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
112112
const data_type_t src_dt = desc.src_desc.data_type;
113113
const data_type_t wei_dt = desc.weights_desc.data_type;
114114
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) &&
115-
one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) ||
115+
one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1, data_type::u2)) ||
116116
(one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16));
117117
auto attr_mask = smask_t::none;
118118
// From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled
@@ -142,7 +142,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
142142
data_type::s32);
143143

144144
if (engine->kind() == engine_kind::cpu)
145-
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1);
145+
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1, data_type::u2);
146146
if (is_int8) fwd_attr_mask |= smask_t::scales | smask_t::zero_points | smask_t::src_dyn_quant_params;
147147

148148
if (is_weight_compression) {

src/common/memory_desc_wrapper.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ struct memory_desc_wrapper : public c_compatible {
157157
* For the rest data types returns 1. */
158158
size_t sub_byte_data_type_multiplier() const {
159159
if (utils::one_of(data_type(), data_type::s4, data_type::u4)) return 2;
160+
if (data_type() == data_type::u2) return 4;
160161
return 1;
161162
}
162163

src/common/type_helpers.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ inline size_t data_type_size(data_type_t data_type) {
108108
case u8: return sizeof(prec_traits_t<u8>::type);
109109
case s4: return sizeof(prec_traits_t<s4>::type);
110110
case u4: return sizeof(prec_traits_t<u4>::type);
111+
case u2: return sizeof(prec_traits_t<u2>::type);
111112
case boolean: return sizeof(prec_traits_t<boolean>::type);
112113
case bin: return sizeof(prec_traits_t<u8>::type);
113114
case nf4: return sizeof(prec_traits_t<u8>::type);
@@ -124,6 +125,7 @@ inline size_t elements_to_bytes(data_type_t data_type, size_t count) {
124125
case f4_e3m0:
125126
case s4:
126127
case u4: return (count + 1) >> 1;
128+
case u2: return (count + 3) >> 2;
127129
default: return data_type_size(data_type) * count;
128130
}
129131
}
@@ -135,6 +137,7 @@ inline size_t bytes_to_elements(data_type_t data_type, size_t bytes) {
135137
case f4_e3m0:
136138
case s4:
137139
case u4: return bytes * 2;
140+
case u2: return bytes * 4;
138141
default: return utils::div_up(bytes, data_type_size(data_type));
139142
}
140143
}
@@ -453,7 +456,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
453456

454457
/* prop_kind doesn't matter */
455458
if (everyone_is(f32, src_dt, wei_dt)) return f32;
456-
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) return f32;
459+
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1, u2)) return f32;
457460
if (everyone_is(f64, src_dt, wei_dt)) return f64;
458461

459462
if (one_of(prop_kind, forward_training, forward_inference)) {
@@ -1301,7 +1304,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
13011304

13021305
bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
13031306
&& utils::one_of(data_type, f4_e3m0, f4_e2m1, e8m0, f8_e5m2,
1304-
f8_e4m3, f16, bf16, f32, f64, s32, s8, u8, s4, u4, bin, nf4);
1307+
f8_e4m3, f16, bf16, f32, f64, s32, s8, u8, s4, u4, u2, bin, nf4);
13051308
if (!ok) return false;
13061309

13071310
bool has_runtime_dims = false;

src/cpu/cpu_inner_product_list.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
9999
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
100100
nullptr,
101101
}},
102+
{{forward, f32, u2, f32}, {
103+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni)
104+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
105+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2_vnni)
106+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
107+
nullptr,
108+
}},
102109
{{forward, f32, f16, f32}, {
103110
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
104111
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
@@ -188,6 +195,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
188195
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
189196
nullptr,
190197
}},
198+
{{forward, bf16, u2, f32}, {
199+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
200+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
201+
nullptr,
202+
}},
203+
{{forward, bf16, u2, bf16}, {
204+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
205+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
206+
nullptr,
207+
}},
191208
{{forward, f16, f16, f32}, {
192209
//CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
193210
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx_fp16)

src/cpu/cpu_primitive.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
= ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | (arg)); \
8585
VCHECK_ATTR(utils::one_of(zero_points_d.data_type(), \
8686
data_type::s32, data_type::s8, data_type::u8, \
87-
data_type::s4, data_type::u4, data_type::f32), \
87+
data_type::s4, data_type::u4, data_type::u2, data_type::f32), \
8888
VERBOSE_INVALID_DATATYPE, "zero points"); \
8989
} \
9090
} \

src/cpu/reorder/cpu_reorder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ regular_impl_list_map() {
5151
{{f32, u4, 0}, &regular_u4_impl_list_map()},
5252
{{s4, data_type::undef, 0}, &regular_s4_impl_list_map()},
5353
{{u4, data_type::undef, 0}, &regular_u4_impl_list_map()},
54+
{{u2, data_type::undef, 0}, &regular_u2_impl_list_map()},
5455
{{bin, data_type::undef, 0}, &regular_bin_impl_list_map()},
5556
{{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()},
5657
{{s4, f32, 0}, &regular_s4_impl_list_map()},

0 commit comments

Comments
 (0)