88#include < map>
99#include < iostream>
1010#include < algorithm>
11+ #include " mlasi.h"
1112#include " mlasi_kleidiai.h"
1213#include < functional>
1314#include < unordered_map>
@@ -298,7 +299,7 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
298299 const size_t kw, const void * const * lhs_ptrs, std::byte* lhs_data,
299300 const float * in_data,
300301 const float * pad_ptr) {
301- size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
302+ size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
302303 : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
303304
304305 // Minimize the kernel call count for the number of available threads
@@ -383,8 +384,8 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
383384
384385 const auto m = ComputeConvOutSize (ih, kh, padding, sh) * ComputeConvOutSize (iw, kw, padding, sw);
385386
386- const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
387- : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
387+ const auto m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
388+ : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
388389
389390 const auto lhs_ptrs_k = kh * kw;
390391 const auto lhs_ptrs_m = m_step * MlasDivRoundup (m, m_step);
@@ -518,10 +519,10 @@ static void ConvolveSme(const size_t co, //channels out
518519 const auto m = ComputeConvOutSize (ih, d_kh, padding, sh) *
519520 ComputeConvOutSize (iw, d_kw, padding, sw);
520521
521- size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
522- : kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
523- size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
524- : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
522+ size_t n_step = SMEInfo::CanUseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
523+ : kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
524+ size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
525+ : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
525526
526527 // tile iteration dimensions
527528 std::array<size_t ,3 > dim;
@@ -566,16 +567,16 @@ static void ConvolveSme(const size_t co, //channels out
566567 ptrdiff_t NIdx = (tid % (dim[1 ] * dim[2 ])) % dim[2 ];
567568
568569 // Get rhs tile, B
569- const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (NIdx * n_step, d_kh * d_kw, ci)
570- : kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (NIdx * n_step, d_kh * d_kw, ci);
570+ const size_t rhs_packed_offset = SMEInfo::CanUseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (NIdx * n_step, d_kh * d_kw, ci)
571+ : kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (NIdx * n_step, d_kh * d_kw, ci);
571572
572573 auto BTile = reinterpret_cast <const void *>(
573574 reinterpret_cast <const std::byte*>(rhs.get ()) + rhs_packed_offset
574575 );
575576
576577 // Get lhs tile, A
577- const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (MIdx * m_step, d_kh * d_kw, ci)
578- : kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (MIdx * m_step, d_kh * d_kw, ci);
578+ const size_t lhs_packed_offset = SMEInfo::CanUseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (MIdx * m_step, d_kh * d_kw, ci)
579+ : kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (MIdx * m_step, d_kh * d_kw, ci);
579580
580581 auto ATile = reinterpret_cast <const float *>(
581582 reinterpret_cast <const std::byte*>(lhs.get ()) + lhs_packed_offset
@@ -589,7 +590,7 @@ static void ConvolveSme(const size_t co, //channels out
589590 MIdx * m_step * co * sizeof (float ) +
590591 NIdx * n_step * sizeof (float )];
591592
592- if (ArmKleidiAI::UseSME2 ) {
593+ if (SMEInfo::CanUseSME2 ) {
593594 KLEIDIAI_KERNEL_LOG (" kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
594595 kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (
595596 TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof (float ),
0 commit comments