Skip to content

Commit e8ab1b1

Browse files
committed
Update const for kernel interface and sme checks
Signed-off-by: Jonathan Clohessy <[email protected]>
1 parent e0668fc commit e8ab1b1

File tree

5 files changed

+13
-8
lines changed

5 files changed

+13
-8
lines changed

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
#include <algorithm>
1616
#include <vector>
1717

18+
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
19+
#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h"
20+
#endif
21+
1822
namespace onnxruntime {
1923
namespace contrib {
2024

@@ -215,7 +219,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
215219

216220
// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
217221
// We check that here too before attempting to use them.
218-
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
222+
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
219223
can_use_dynamic_quant_mlas_ = false;
220224
}
221225

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -862,9 +862,9 @@ Return Value:
862862
--*/
863863
{
864864
// Override
865-
if(GetMlasPlatform().MlasConvOverride != nullptr &&
865+
if(ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr &&
866866
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){
867-
return;
867+
return;
868868
}
869869

870870
const size_t FilterCount = Parameters->FilterCount;
@@ -1101,7 +1101,7 @@ Return Value:
11011101
--*/
11021102
{
11031103
// Override
1104-
if (GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
1104+
if (ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
11051105
GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
11061106
InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount,
11071107
Activation, WorkingBufferSize, Beta, ThreadPool)){

onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ struct KaiTlsBuffers {
2727
};
2828
static thread_local KaiTlsBuffers g_kai_tls;
2929

30-
kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
31-
kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
30+
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
31+
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
3232

3333

3434
// Helpers for GEMV

onnxruntime/core/mlas/lib/qgemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ MlasDynamicQgemmPackB(
423423
{
424424
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
425425
//No fallback
426-
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
426+
if (ArmKleidiAI::SMEInfo::CanUseSME2) {//Still require this since no override
427427
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
428428
}
429429
#endif

onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "test_util.h"
1111
#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO
12+
#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h"
1213

1314
class MlasDynamicQgemmTest {
1415
private:
@@ -21,7 +22,7 @@ class MlasDynamicQgemmTest {
2122
public:
2223
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
2324
// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
24-
if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
25+
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
2526
GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test.";
2627
}
2728

0 commit comments

Comments
 (0)