@@ -108,6 +108,7 @@ inline size_t data_type_size(data_type_t data_type) {
108108 case u8 : return sizeof (prec_traits_t <u8 >::type);
109109 case s4: return sizeof (prec_traits_t <s4>::type);
110110 case u4: return sizeof (prec_traits_t <u4>::type);
111+ case u2: return sizeof (prec_traits_t <u2>::type);
111112 case boolean: return sizeof (prec_traits_t <boolean>::type);
112113 case bin: return sizeof (prec_traits_t <u8 >::type);
113114 case nf4: return sizeof (prec_traits_t <u8 >::type);
@@ -124,6 +125,7 @@ inline size_t elements_to_bytes(data_type_t data_type, size_t count) {
124125 case f4_e3m0:
125126 case s4:
126127 case u4: return (count + 1 ) >> 1 ;
128+ case u2: return (count + 3 ) >> 2 ;
127129 default : return data_type_size (data_type) * count;
128130 }
129131}
@@ -135,6 +137,7 @@ inline size_t bytes_to_elements(data_type_t data_type, size_t bytes) {
135137 case f4_e3m0:
136138 case s4:
137139 case u4: return bytes * 2 ;
140+ case u2: return bytes * 4 ;
138141 default : return utils::div_up (bytes, data_type_size (data_type));
139142 }
140143}
@@ -453,7 +456,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
453456
454457 /* prop_kind doesn't matter */
455458 if (everyone_is (f32 , src_dt, wei_dt)) return f32 ;
456- if (one_of (src_dt, f32 , bf16 ) && one_of (wei_dt, u8 , s8, nf4, s4, u4, f4_e2m1)) return f32 ;
459+ if (one_of (src_dt, f32 , bf16 ) && one_of (wei_dt, u8 , s8, nf4, s4, u4, f4_e2m1, u2 )) return f32 ;
457460 if (everyone_is (f64 , src_dt, wei_dt)) return f64 ;
458461
459462 if (one_of (prop_kind, forward_training, forward_inference)) {
@@ -1301,7 +1304,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
13011304
13021305 bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
13031306 && utils::one_of (data_type, f4_e3m0, f4_e2m1, e8m0, f8_e5m2,
1304- f8_e4m3, f16 , bf16 , f32 , f64 , s32, s8, u8 , s4, u4, bin, nf4);
1307+ f8_e4m3, f16 , bf16 , f32 , f64 , s32, s8, u8 , s4, u4, u2, bin, nf4);
13051308 if (!ok) return false ;
13061309
13071310 bool has_runtime_dims = false ;
0 commit comments