Skip to content

Commit b5519d3

Browse files
authored
[CPU] enable brdgmm_dw_conv kernels with non f32 bias (#28076)
### Details: - *brdgmm_dw_conv kernels support only bia_type the same as src_type or dst_type* ### Tickets: - *CVS-157009*
1 parent 200b67a commit b5519d3

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/plugins/intel_cpu/src/nodes/conv.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,31 @@ void Convolution::createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
980980
memory::data_type bdt = outDnnlDesc.get_data_type();
981981
#else
982982
memory::data_type bdt = memory::data_type::f32;
983+
/* brdgmm_dw_conv has more perf gain on bf16/fp16 inference.
984+
brdgmm_dw_conv supports only bia_type the same as src_type or dst_type.
985+
dw convolution support in oneDNN 3.5:
986+
BF16:
987+
kernel type | brgdconv | jit_uni_dw_convolution_fwd_t
988+
support impl type | native bf16 ISA without AMX | avx512_core_bf16 or avx512_core
989+
bias dt | oneof(src,dest) | oneof(src, dest, f32)
990+
FP16:
991+
kernel type | brgdconv | brgemm_convolution_fwd_t
992+
impl type | native FP16 ISA without AMX | native FP16 ISA
993+
bias type | oneof(src,dest) | oneof(src, dest, f32)
994+
@todo: this bias type change may have a minor accuracy impact on some models, so when upstream oneDNN extends this
995+
kind of matrix support (ticket MFDNN-12936) we can continue to use bdt = memory::data_type::f32 here;
996+
*/
997+
auto out_dt = outDnnlDesc.get_data_type();
998+
if (!canBeExecutedInInt8() && isDepthWise()) {
999+
bool isF16BiasSupported = (out_dt == memory::data_type::f16) && hasHardwareSupport(ov::element::f16);
1000+
bool isBF16BiasSupported = (out_dt == memory::data_type::bf16) &&
1001+
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) ||
1002+
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2));
1003+
1004+
if (isF16BiasSupported || isBF16BiasSupported) {
1005+
bdt = out_dt;
1006+
}
1007+
}
9831008
#endif
9841009
biasDnnlDesc =
9851010
dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(expectedBiasDims), bdt, memory::format_tag::any);

0 commit comments

Comments
 (0)