Skip to content

Commit 4cf6ccd

Browse files
committed
Update const for kernel interface and sme checks
Signed-off-by: Jonathan Clohessy <[email protected]>
1 parent 82480ad commit 4cf6ccd

File tree

5 files changed

+13
-8
lines changed

5 files changed

+13
-8
lines changed

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
#include <algorithm>
1616
#include <vector>
1717

18+
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
19+
#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h"
20+
#endif
21+
1822
namespace onnxruntime {
1923
namespace contrib {
2024

@@ -215,7 +219,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
215219

216220
// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
217221
// We check that here too before attempting to use them.
218-
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) {
222+
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
219223
can_use_dynamic_quant_mlas_ = false;
220224
}
221225

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -938,9 +938,9 @@ Return Value:
938938
--*/
939939
{
940940
// Override
941-
if(GetMlasPlatform().MlasConvOverride != nullptr &&
941+
if(ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr &&
942942
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){
943-
return;
943+
return;
944944
}
945945

946946
const size_t FilterCount = Parameters->FilterCount;
@@ -1201,7 +1201,7 @@ Return Value:
12011201
--*/
12021202
{
12031203
// Override
1204-
if (GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
1204+
if (ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
12051205
GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
12061206
InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount,
12071207
Activation, WorkingBufferSize, Beta, ThreadPool)){

onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ struct KaiTlsBuffers {
2727
};
2828
static thread_local KaiTlsBuffers g_kai_tls;
2929

30-
kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
31-
kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
30+
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
31+
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
3232

3333

3434
// Helpers for GEMV

onnxruntime/core/mlas/lib/qgemm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ Return Value:
407407
~(BufferAlignment - 1);
408408
// If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
409409
// this use case. Concat both packed representations for later decision. This allows for cases later
410-
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
410+
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
411411
// for better performance
412412
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
413413
}
@@ -425,7 +425,7 @@ MlasDynamicQgemmPackB(
425425
{
426426
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
427427
//No fallback
428-
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
428+
if (ArmKleidiAI::SMEInfo::CanUseSME2) {//Still require this since no override
429429
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
430430
}
431431
#endif

onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "test_util.h"
1111
#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO
12+
#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h"
1213

1314
class MlasDynamicQgemmTest {
1415
private:

0 commit comments

Comments
 (0)