Skip to content

Commit 1d9b7c8

Browse files
committed
Modify SME Detection struct location and logic
Signed-off-by: Jonathan Clohessy <[email protected]>
1 parent 0504a85 commit 1d9b7c8

File tree

8 files changed

+44
-30
lines changed

8 files changed

+44
-30
lines changed

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME()
4+
#include "core/common/cpuid_info.h"
55
#include "core/common/narrow.h"
66
#include "core/common/safeint.h"
77
#include "core/mlas/inc/mlas.h"
@@ -219,7 +219,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
219219

220220
// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
221221
// We check that here too before attempting to use them.
222-
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
222+
if (!SMEInfo::CanUseSME2) {
223223
can_use_dynamic_quant_mlas_ = false;
224224
}
225225

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ Return Value:
938938
--*/
939939
{
940940
// Override
941-
if(ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr &&
941+
if(SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr &&
942942
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){
943943
return;
944944
}
@@ -1201,7 +1201,7 @@ Return Value:
12011201
--*/
12021202
{
12031203
// Override
1204-
if (ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
1204+
if (SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
12051205
GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
12061206
InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount,
12071207
Activation, WorkingBufferSize, Beta, ThreadPool)){
@@ -1411,8 +1411,8 @@ Return Value:
14111411

14121412
if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {
14131413

1414-
size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
1415-
Parameters->FilterCount * Parameters->OutputSize,
1414+
size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
1415+
Parameters->FilterCount * Parameters->OutputSize,
14161416
static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
14171417
TargetThreadCount = MaximumThreadCount;
14181418
if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,15 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
139139
}
140140

141141
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() {
142-
if (ArmKleidiAI::SMEInfo::CanUseSME2) {
142+
if (SMEInfo::CanUseSME2) {
143143
return sgemm_gemm_sme2;
144144
} else {
145145
return sgemm_gemm_sme;
146146
}
147147
}
148148

149149
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
150-
if (ArmKleidiAI::SMEInfo::CanUseSME2) {
150+
if (SMEInfo::CanUseSME2) {
151151
return sgemm_gemv_sme2;
152152
} else {
153153
return sgemm_gemv_sme;

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#pragma once
88

9-
#include "mlasi.h"
9+
#include "../mlasi.h"
1010
#include <iostream>
1111

1212
// Fix to ensure compatibility with MSVC build
@@ -51,21 +51,6 @@
5151

5252
namespace ArmKleidiAI {
5353

54-
struct SMEInfo {
55-
static const bool CanUseSME2;
56-
static const bool CanUseSME;
57-
static const bool IsSMEAvailable;
58-
};
59-
60-
// Boolean condition to determine if we can use SME2
61-
// By default we should try for SME2 first before falling back to SME.
62-
inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
63-
// Boolean condition to determine if we can use SME
64-
inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();
65-
// Boolean condition to tell us if SME is enabled on this system
66-
inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME;
67-
68-
6954
//
7055
// Buffer packing routines.
7156
//

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ MLAS_FORCEINLINE void
158158
// of the ONNX Runtime source tree. OpenMP may or may not be enabled in this
159159
// configuration.
160160
//
161+
struct SMEInfo {
162+
static const bool CanUseSME2;
163+
static const bool CanUseSME;
164+
static const bool IsSMEAvailable;
165+
};
161166

162167
#if !defined(BUILD_MLAS_NO_ONNXRUNTIME)
163168
#include "core/platform/threadpool.h"
@@ -167,6 +172,16 @@ using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo;
167172

168173
#include "core/common/float16.h"
169174

175+
176+
177+
// Boolean condition to determine if we can use SME2
178+
// By default we should try for SME2 first before falling back to SME.
179+
inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
180+
// Boolean condition to determine if we can use SME
181+
inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();
182+
// Boolean condition to tell us if SME is enabled on this system
183+
inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME;
184+
170185
#else // BUILD_MLAS_NO_ONNXRUNTIME
171186

172187
class MLASCPUIDInfo
@@ -201,6 +216,10 @@ class MLASCPUIDInfo
201216

202217
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }
203218

219+
bool HasArm_SME() const { return has_arm_sme_; }
220+
221+
bool HasArm_SME2() const { return has_arm_sme2_; }
222+
204223
private:
205224
MLASCPUIDInfo();
206225

@@ -210,9 +229,19 @@ class MLASCPUIDInfo
210229
bool has_arm_sve_{false};
211230
bool has_arm_sve_i8mm_{false};
212231
bool has_arm_neon_bf16_{false};
232+
bool has_arm_sme_{false};
233+
bool has_arm_sme2_{false};
213234
};
214235
using MLAS_CPUIDINFO = MLASCPUIDInfo;
215236

237+
// Boolean condition to determine if we can use SME2
238+
// By default we should try for SME2 first before falling back to SME.
239+
inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
240+
// Boolean condition to determine if we can use SME
241+
inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();
242+
// Boolean condition to tell us if SME is enabled on this system
243+
inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME;
244+
216245
#if defined(MLAS_TARGET_ARM64)
217246
/**
218247
* @brief IDs for cpu microarchitectures.

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ Return Value:
601601
}
602602

603603
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
604-
if (ArmKleidiAI::SMEInfo::IsSMEAvailable) {
604+
if (SMEInfo::IsSMEAvailable) {
605605
this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch;
606606
this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize;
607607
this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB;

onnxruntime/core/mlas/lib/qgemm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ MlasDynamicQGemmBatch (
211211
) {
212212
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
213213
// No fallback is available, so guard the call on SME2 support.
214-
if(ArmKleidiAI::SMEInfo::CanUseSME2){
214+
if(SMEInfo::CanUseSME2){
215215
ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
216216
}
217217
#endif
@@ -336,7 +336,7 @@ MlasDynamicQgemmPackBSize(
336336
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
337337
//No fallback available
338338
//TODO: Insert Override
339-
if(ArmKleidiAI::SMEInfo::CanUseSME2){//Still require this since no override
339+
if(SMEInfo::CanUseSME2){//Still require this since no override
340340
bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K);
341341
}
342342
#endif
@@ -407,7 +407,7 @@ Return Value:
407407
~(BufferAlignment - 1);
408408
// If this gemm B argument is used in a dynamically quantized gemm operation we can optimize for
409409
// this use case. Concat both packed representations for later decision. This allows for cases later
410-
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
410+
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
411411
// for better performance
412412
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
413413
}
@@ -425,7 +425,7 @@ MlasDynamicQgemmPackB(
425425
{
426426
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
427427
//No fallback
428-
if (ArmKleidiAI::SMEInfo::CanUseSME2) {//Still require this since no override
428+
if (SMEInfo::CanUseSME2) {//Still require this since no override
429429
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
430430
}
431431
#endif

onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MlasDynamicQgemmTest {
2222
public:
2323
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
2424
// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
25-
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
25+
if (!SMEInfo::CanUseSME2) {
2626
GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test.";
2727
}
2828

0 commit comments

Comments
 (0)