@@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) {
2222 delete ptr;
2323}
2424
25+ DNNLScratchPadManager::DNNLScratchPadManager () : size_(0 ), ptr_(nullptr ) {
26+ this ->realloc (allocation_unit * 128 );
27+ }
28+
29+ void DNNLScratchPadManager::realloc (size_t new_size) {
30+ new_size = round (new_size);
31+ if (new_size > size_) {
32+ ptr_ = std::aligned_alloc (64 , new_size);
33+ size_ = new_size;
34+ }
35+ }
36+
37+ DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager () {
38+ static DNNLScratchPadManager manager;
39+ return &manager;
40+ }
41+
2542template <typename KT, typename VT>
2643class DNNLPrimitiveCache {
2744 public:
@@ -166,6 +183,23 @@ struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
166183 hash<int >()(static_cast <int >(val.bias_type ));
167184 }
168185};
186+
187+ template <>
188+ struct hash <MatMulPrimitiveHandler::ClassMatmulCacheKey> {
189+ size_t operator ()(
190+ const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
191+ return hash<dnnl_dim_t >()(val.b_n_size ) ^ hash<dnnl_dim_t >()(val.b_k_size );
192+ }
193+ };
194+
195+ template <>
196+ struct hash <MatMulPrimitiveHandler::MSizeCacheKey> {
197+ size_t operator ()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
198+ return hash<dnnl_dim_t >()(val.a_m_size ) ^
199+ hash<dnnl_dim_t >()(val.a_m_stride ) ^ hash<bool >()(val.use_bias ) ^
200+ hash<int >()(static_cast <int >(val.bias_type ));
201+ }
202+ };
169203} // namespace std
170204
171205bool operator ==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
@@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
181215 l.bias_type == r.bias_type ;
182216}
183217
218+ bool operator ==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
219+ const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
220+ return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size ;
221+ }
222+
223+ bool operator ==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
224+ const MatMulPrimitiveHandler::MSizeCacheKey& r) {
225+ return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride &&
226+ l.use_bias == r.use_bias && l.bias_type == r.bias_type ;
227+ }
228+
184229static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
185230get_w8a8_class_primitive_cache (
186231 const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
@@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
239284 }
240285
241286 dnnl::matmul matmul = get_matmul_cache (args);
287+
288+ auto && [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr (5 );
289+ scratchpad_storage->set_data_handle (
290+ DNNLScratchPadManager::get_dnnl_scratchpad_manager ()->get_data <void >());
291+
242292 matmul.execute (default_stream (), memory_cache_);
243293 default_stream ().wait ();
244294}
@@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
257307
258308 return m_size_cache_->get_or_create (key, [&]() {
259309 dnnl::matmul::primitive_desc desc = this ->create_primitive_desc (key, false );
310+ auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager ();
311+ manager->realloc (desc.scratchpad_desc ().get_size ());
260312 return dnnl::matmul (desc);
261313 });
262314}
@@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
300352 dnnl::memory ({{b_n_size_}, dnnl::memory::data_type::f32 , {1 }},
301353 default_engine (), nullptr );
302354 set_runtime_memory_ptr (4 , memory_cache_[DNNL_ARG_BIAS].get ());
355+
356+ memory_cache_[DNNL_ARG_SCRATCHPAD] =
357+ dnnl::memory ({{b_n_size_}, dnnl::memory::data_type::f32 , {1 }},
358+ default_engine (), nullptr );
359+ set_runtime_memory_ptr (5 , memory_cache_[DNNL_ARG_SCRATCHPAD].get ());
303360}
304361
305362dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc (
@@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
319376 dnnl::memory::format_tag::ab);
320377
321378 dnnl::primitive_attr attr;
379+
380+ attr.set_scratchpad_mode (dnnl::scratchpad_mode::user);
381+
322382 // For PER_TOKEN, scales will be applied in outside epilogue
323383 if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
324384 attr.set_scales_mask (DNNL_ARG_SRC, 0 );
@@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
344404 attr);
345405 }
346406}
407+
// Handler for non-quantized (f32/bf16/f16) matmuls. At construction it
// prepacks the weight matrix once into oneDNN's preferred blocked
// layout and sets up the reusable runtime memory objects.
MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
    : DNNLMatMulPrimitiveHandler(
          static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
      m_size_cache_(nullptr) {
  // Only plain floating-point A/B types are supported here; the
  // quantized path is W8A8MatMulPrimitiveHandler.
  assert(ab_type_ == dnnl::memory::data_type::f32 ||
         ab_type_ == dnnl::memory::data_type::bf16 ||
         ab_type_ == dnnl::memory::data_type::f16);
  // Build a primitive_desc with runtime M dims (first_time = true, so
  // the weight layout is format_tag::any) purely to query the optimal
  // weights_desc(), then reorder the user weights into that layout.
  prepack_weight(args.b_ptr,
                 create_primitive_desc(
                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
                                   .use_bias = false,
                                   .bias_type = dnnl::memory::data_type::undef},
                     true)
                     .weights_desc());
  init_runtime_memory_cache(args);
}
425+
426+ static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
427+ get_matul_class_primitive_cache (
428+ const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
429+ int64_t cache_size) {
430+ static MatMulPrimitiveHandler::ClassMatmulCache cache (128 );
431+ assert (cache_size > 0 );
432+ return cache.get_or_create (key, [&]() {
433+ return std::make_shared<MatMulPrimitiveHandler::MSizeCache>(cache_size);
434+ });
435+ }
436+
// Runs one matmul. The cached dnnl::memory objects own no storage;
// each call rebinds their data handles to the caller's buffers and
// patches the runtime M dimension/stride in the memory descriptors.
// Slot indices: 0 = SRC, 1 = DST, 2 = BIAS, 3 = SCRATCHPAD.
void MatMulPrimitiveHandler::execute(ExecArgs& args) {
  auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
  auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
  // Bind input/output pointers and the actual M size for this call;
  // the descriptor fields are mutated in place.
  a_storage->set_data_handle((void*)args.a_ptr);
  a_mem_desc->dims[0] = args.a_m_size;
  a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride;
  c_storage->set_data_handle((void*)args.c_ptr);
  c_mem_desc->dims[0] = args.a_m_size;

  if (args.use_bias) {
    auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
    bias_storage->set_data_handle((void*)args.bias_ptr);
  }

  // Fetch (or build) the primitive for this M configuration; building
  // it also grows the shared scratchpad as needed.
  dnnl::matmul matmul = get_matmul_cache(args);

  // Point the scratchpad arg at the process-wide scratchpad buffer.
  auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
  scratchpad_storage->set_data_handle(
      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());

  matmul.execute(default_stream(), memory_cache_);
  default_stream().wait();
}
460+
461+ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache (
462+ const MSizeCacheKey& key) {
463+ if (m_size_cache_.get () == nullptr ) {
464+ ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_};
465+ m_size_cache_ = get_matul_class_primitive_cache (key, primitive_cache_size_);
466+ }
467+ return m_size_cache_->get_or_create (key, [&]() {
468+ dnnl::matmul::primitive_desc desc = this ->create_primitive_desc (key, false );
469+ auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager ();
470+ manager->realloc (desc.scratchpad_desc ().get_size ());
471+ return dnnl::matmul (desc);
472+ });
473+ }
474+
475+ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc (
476+ const MSizeCacheKey& key, bool first_time) {
477+ dnnl::memory::desc a_md;
478+ dnnl::memory::desc b_md;
479+ if (first_time) {
480+ a_md = dnnl::memory::desc ({key.a_m_size , b_k_size_}, b_type_,
481+ dnnl::memory::format_tag::ab);
482+ b_md = dnnl::memory::desc ({b_k_size_, b_n_size_}, b_type_,
483+ dnnl::memory::format_tag::any);
484+ } else {
485+ a_md = dnnl::memory::desc ({key.a_m_size , b_k_size_}, b_type_,
486+ {key.a_m_stride , 1 });
487+ b_md = b_target_mem_desc_;
488+ }
489+ dnnl::memory::desc c_md ({key.a_m_size , b_n_size_}, c_type_,
490+ dnnl::memory::format_tag::ab);
491+
492+ dnnl::primitive_attr attr;
493+ attr.set_scratchpad_mode (dnnl::scratchpad_mode::user);
494+
495+ if (key.use_bias ) {
496+ dnnl::memory::desc bias_md ({1 , b_n_size_}, key.bias_type , {b_n_size_, 1 });
497+ return dnnl::matmul::primitive_desc (default_engine (), a_md, b_md, bias_md,
498+ c_md, attr);
499+ } else {
500+ return dnnl::matmul::primitive_desc (default_engine (), a_md, b_md, c_md,
501+ attr);
502+ }
503+ }
504+
// Pre-creates dnnl::memory objects with null data handles; execute()
// rebinds the handles (and patches the runtime M dimension) on every
// call. Runtime slot indices: 0 = SRC, 1 = DST, 2 = BIAS,
// 3 = SCRATCHPAD. `args` is currently unused here.
void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
  memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
      {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
  memory_cache_[DNNL_ARG_DST] =
      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

  // NOTE(review): bias is described as f32 even though MSizeCacheKey
  // carries a bias_type — presumably only the data handle matters for
  // these runtime-bound memories; confirm against the execute() path.
  memory_cache_[DNNL_ARG_BIAS] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());

  // Placeholder descriptor: the scratchpad is typeless storage and its
  // real pointer is bound from DNNLScratchPadManager at execute() time.
  memory_cache_[DNNL_ARG_SCRATCHPAD] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}
0 commit comments