Skip to content

Commit 1ead7ca

Browse files
committed
Align convolve checks with consolidated smeinfo mechanism
Signed-off-by: Jonathan Clohessy <[email protected]>
1 parent 733cd76 commit 1ead7ca

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ Return Value:
938938
--*/
939939
{
940940
// Override
941-
if(SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr &&
941+
if(SMEInfo::IsSMEAvailable && GetMlasPlatform().MlasConvOverride != nullptr &&
942942
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){
943943
return;
944944
}
@@ -1201,7 +1201,7 @@ Return Value:
12011201
--*/
12021202
{
12031203
// Override
1204-
if (SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
1204+
if (SMEInfo::IsSMEAvailable && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
12051205
GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
12061206
InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount,
12071207
Activation, WorkingBufferSize, Beta, ThreadPool)){

onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <map>
99
#include <iostream>
1010
#include <algorithm>
11+
#include "mlasi.h"
1112
#include "mlasi_kleidiai.h"
1213
#include <functional>
1314
#include <unordered_map>
@@ -298,7 +299,7 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
298299
const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
299300
const float* in_data,
300301
const float* pad_ptr) {
301-
size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
302+
size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
302303
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
303304

304305
// Minimize the kernel call count for the number of available threads
@@ -383,8 +384,8 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
383384

384385
const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);
385386

386-
const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
387-
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
387+
const auto m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
388+
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
388389

389390
const auto lhs_ptrs_k = kh * kw;
390391
const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step);
@@ -518,10 +519,10 @@ static void ConvolveSme(const size_t co, //channels out
518519
const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
519520
ComputeConvOutSize(iw, d_kw, padding, sw);
520521

521-
size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
522-
: kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
523-
size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
524-
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
522+
size_t n_step = SMEInfo::CanUseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
523+
: kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
524+
size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
525+
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
525526

526527
// tile iteration dimensions
527528
std::array<size_t,3> dim;
@@ -566,16 +567,16 @@ static void ConvolveSme(const size_t co, //channels out
566567
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
567568

568569
// Get rhs tile, B
569-
const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
570-
: kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);
570+
const size_t rhs_packed_offset = SMEInfo::CanUseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
571+
: kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);
571572

572573
auto BTile = reinterpret_cast<const void*>(
573574
reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
574575
);
575576

576577
// Get lhs tile, A
577-
const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
578-
: kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);
578+
const size_t lhs_packed_offset = SMEInfo::CanUseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
579+
: kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);
579580

580581
auto ATile = reinterpret_cast<const float*>(
581582
reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
@@ -589,7 +590,7 @@ static void ConvolveSme(const size_t co, //channels out
589590
MIdx * m_step * co * sizeof(float) +
590591
NIdx * n_step * sizeof(float)];
591592

592-
if (ArmKleidiAI::UseSME2) {
593+
if (SMEInfo::CanUseSME2) {
593594
KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
594595
kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
595596
TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),

0 commit comments

Comments
 (0)