44
55#include " dnnl_extension_utils.h"
66
7- #include " utils/general_utils.h"
87#include < oneapi/dnnl/dnnl.hpp>
98#include " memory_desc/dnnl_blocked_memory_desc.h"
10- #include " onednn/iml_type_mapper.h"
11- #include < common/primitive_desc.hpp>
129#include < common/primitive_desc_iface.hpp>
1310
14- #include < vector>
15-
1611using namespace dnnl ;
1712
1813namespace ov {
1914namespace intel_cpu {
2015
21- uint8_t DnnlExtensionUtils::sizeOfDataType (dnnl:: memory::data_type dataType) {
16+ uint8_t DnnlExtensionUtils::sizeOfDataType (memory::data_type dataType) {
2217 switch (dataType) {
23- case dnnl::memory::data_type::f32 :
24- return 4 ;
25- case dnnl::memory::data_type::s32:
18+ case memory::data_type::f64 :
19+ case memory::data_type::s64:
20+ return 8 ;
21+ case memory::data_type::f32 :
22+ case memory::data_type::s32:
2623 return 4 ;
27- case dnnl::memory::data_type::bf16 :
24+ case memory::data_type::bf16 :
25+ case memory::data_type::f16 :
2826 return 2 ;
29- case dnnl:: memory::data_type::s8:
30- return 1 ;
31- case dnnl:: memory::data_type::u8 :
27+ case memory::data_type::s8:
28+ case memory::data_type:: u8 :
29+ case memory::data_type::bin :
3230 return 1 ;
33- case dnnl::memory::data_type::bin:
34- return 1 ;
35- case dnnl::memory::data_type::f16 :
36- return 2 ;
37- case dnnl::memory::data_type::undef:
31+ case memory::data_type::undef:
3832 return 0 ;
3933 default :
40- IE_THROW () << " Unsupported data type. " ;
34+ IE_THROW () << " Unsupported data type: " << DataTypeToIEPrecision (dataType) ;
4135 }
4236}
4337
4438memory::data_type DnnlExtensionUtils::IEPrecisionToDataType (const InferenceEngine::Precision& prec) {
4539 switch (prec) {
40+ case InferenceEngine::Precision::FP64:
41+ return memory::data_type::f64 ;
42+ case InferenceEngine::Precision::I64:
43+ return memory::data_type::s64;
4644 case InferenceEngine::Precision::FP32:
4745 return memory::data_type::f32 ;
4846 case InferenceEngine::Precision::I32:
@@ -68,6 +66,10 @@ memory::data_type DnnlExtensionUtils::IEPrecisionToDataType(const InferenceEngin
6866
6967InferenceEngine::Precision DnnlExtensionUtils::DataTypeToIEPrecision (memory::data_type dataType) {
7068 switch (dataType) {
69+ case memory::data_type::f64 :
70+ return InferenceEngine::Precision::FP64;
71+ case memory::data_type::s64:
72+ return InferenceEngine::Precision::I64;
7173 case memory::data_type::f32 :
7274 return InferenceEngine::Precision::FP32;
7375 case memory::data_type::s32:
@@ -90,11 +92,11 @@ InferenceEngine::Precision DnnlExtensionUtils::DataTypeToIEPrecision(memory::dat
9092 }
9193}
9294
93- Dim DnnlExtensionUtils::convertToDim (const dnnl:: memory::dim &dim) {
95+ Dim DnnlExtensionUtils::convertToDim (const memory::dim &dim) {
9496 return dim == DNNL_RUNTIME_DIM_VAL ? Shape::UNDEFINED_DIM : static_cast <size_t >(dim);
9597}
96- dnnl:: memory::dim DnnlExtensionUtils::convertToDnnlDim (const Dim &dim) {
97- return dim == Shape::UNDEFINED_DIM ? DNNL_RUNTIME_DIM_VAL : static_cast <dnnl:: memory::dim>(dim);
98+ memory::dim DnnlExtensionUtils::convertToDnnlDim (const Dim &dim) {
99+ return dim == Shape::UNDEFINED_DIM ? DNNL_RUNTIME_DIM_VAL : static_cast <memory::dim>(dim);
98100}
99101
100102VectorDims DnnlExtensionUtils::convertToVectorDims (const memory::dims& dims) {
@@ -133,19 +135,19 @@ memory::format_tag DnnlExtensionUtils::GetPlainFormatByRank(size_t rank) {
133135 }
134136}
135137
136- DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor (const dnnl:: memory::desc &desc) {
138+ DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor (const memory::desc &desc) {
137139 return makeDescriptor (desc.get ());
138140}
139141
140142DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor (const_dnnl_memory_desc_t desc) {
141- if (desc->format_kind == dnnl:: impl::format_kind_t ::dnnl_blocked) {
143+ if (desc->format_kind == impl::format_kind_t ::dnnl_blocked) {
142144 return std::shared_ptr<DnnlBlockedMemoryDesc>(new DnnlBlockedMemoryDesc (desc));
143145 } else {
144146 return std::shared_ptr<DnnlMemoryDesc>(new DnnlMemoryDesc (desc));
145147 }
146148}
147149
148- size_t DnnlExtensionUtils::getMemSizeForDnnlDesc (const dnnl:: memory::desc& desc) {
150+ size_t DnnlExtensionUtils::getMemSizeForDnnlDesc (const memory::desc& desc) {
149151 auto tmpDesc = desc;
150152
151153 const auto offset0 = tmpDesc.get ()->offset0 ;
@@ -167,8 +169,8 @@ std::shared_ptr<DnnlBlockedMemoryDesc> DnnlExtensionUtils::makeUndefinedDesc(con
167169 }
168170}
169171
170- DnnlMemoryDescPtr DnnlExtensionUtils::query_md (const const_dnnl_primitive_desc_t & pd, const dnnl:: query& what, int idx) {
171- auto query = dnnl:: convert_to_c (what);
172+ DnnlMemoryDescPtr DnnlExtensionUtils::query_md (const const_dnnl_primitive_desc_t & pd, const query& what, int idx) {
173+ auto query = convert_to_c (what);
172174 const auto * cdesc = dnnl_primitive_desc_query_md (pd, query, idx);
173175
174176 if (!cdesc)
0 commit comments