Skip to content

Commit 61ff403

Browse files
authored
Add MlasIsDynamicQGemmAvailable() helper and use that in place of platform-specific checks (#26668)
### Description <!-- Describe your changes. --> Add `MlasIsDynamicQGemmAvailable()` helper function and use that in place of platform-specific checks. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Try to reduce platform-specific code.
1 parent 06f6f1c commit 61ff403

File tree

4 files changed

+45
-27
lines changed

4 files changed

+45
-27
lines changed

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME2()
54
#include "core/common/narrow.h"
65
#include "core/common/safeint.h"
76
#include "core/mlas/inc/mlas.h"
@@ -213,9 +212,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
213212
}
214213
}
215214

216-
// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
217-
// We check that here too before attempting to use them.
218-
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) {
215+
if (!MlasIsDynamicQGemmAvailable()) {
219216
can_use_dynamic_quant_mlas_ = false;
220217
}
221218

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ MlasGemm(
634634
{
635635
MlasGemmBatch(Shape, &DataParams, 1, ThreadPool);
636636
}
637+
637638
/**
638639
* @brief Parameters that define the shape of a dynamically quantized GEMM operation.
639640
*
@@ -646,6 +647,7 @@ struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS {
646647
size_t N = 0; /**< Column size of matrix B */
647648
size_t K = 0; /**< Column size of matrix A and Row size of matrix B */
648649
};
650+
649651
/**
650652
* @brief Parameters that define the data buffers and layout for a dynamic quant GEMM.
651653
*
@@ -680,6 +682,14 @@ MlasDynamicQGemm (
680682
MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool);
681683
}
682684

685+
/**
686+
* @brief Determines whether a dynamic quantized GEMM implementation is available on the current platform.
687+
*
688+
* MlasDynamicQGemm() and MlasDynamicQGemmBatch() should only be called if this function returns true.
689+
*/
690+
bool
691+
MLASCALL
692+
MlasIsDynamicQGemmAvailable();
683693

684694
//
685695
// Symmetric QGEMM has limited buffer overrun.

onnxruntime/core/mlas/lib/qgemm.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ MlasGemmBatch(
201201
});
202202
}
203203

204+
bool
205+
MLASCALL
206+
MlasIsDynamicQGemmAvailable()
207+
{
208+
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
209+
return ArmKleidiAI::UseSME2;
210+
#else
211+
return false;
212+
#endif
213+
}
214+
204215
void
205216
MLASCALL
206217
MlasDynamicQGemmBatch (
@@ -209,11 +220,11 @@ MlasDynamicQGemmBatch (
209220
const size_t BatchN,
210221
MLAS_THREADPOOL* ThreadPool
211222
) {
223+
assert(MlasIsDynamicQGemmAvailable());
224+
212225
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
213-
//No fallback and putting in guards. This implementation is SME2 specific.
214-
if(ArmKleidiAI::UseSME2){
215-
ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
216-
}
226+
//No fallback
227+
ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
217228
#endif
218229

219230
MLAS_UNREFERENCED_PARAMETER(Shape);
@@ -332,13 +343,13 @@ MlasDynamicQgemmPackBSize(
332343
size_t K
333344
)
334345
{
346+
assert(MlasIsDynamicQGemmAvailable());
347+
335348
size_t bytes = 0;
336349
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
337350
//No fallback available
338351
//TODO: Insert Override
339-
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
340-
bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K);
341-
}
352+
bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K);
342353
#endif
343354

344355
MLAS_UNREFERENCED_PARAMETER(N);
@@ -405,11 +416,15 @@ Return Value:
405416
const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
406417
const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) &
407418
~(BufferAlignment - 1);
408-
// If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
419+
// If this gemm B argument is used in a dynamically quantized gemm operation we can optimize for
409420
// this use case. Concat both packed representations for later decision. This allows for cases later
410-
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
421+
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
411422
// for better performance
412-
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
423+
if (MlasIsDynamicQGemmAvailable()) {
424+
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
425+
} else {
426+
return AlignedBytesRequired;
427+
}
413428
}
414429

415430
void
@@ -423,11 +438,11 @@ MlasDynamicQgemmPackB(
423438
void* PackedB
424439
)
425440
{
441+
assert(MlasIsDynamicQGemmAvailable());
442+
426443
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
427444
//No fallback
428-
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
429-
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
430-
}
445+
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
431446
#endif
432447

433448
MLAS_UNREFERENCED_PARAMETER(N);

onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
// SPDX-License-Identifier: MIT
55
//
66

7-
// Currently this test only applies to KleidiAI Guard against it running in any other situation
8-
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
9-
7+
#include "mlas.h"
108
#include "test_util.h"
11-
#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO
129

1310
class MlasDynamicQgemmTest {
1411
private:
@@ -20,11 +17,6 @@ class MlasDynamicQgemmTest {
2017

2118
public:
2219
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
23-
// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
24-
if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
25-
GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME2 but it was not detected. Skipping test.";
26-
}
27-
2820
// Setup buffers for holding various data
2921

3022
float* A = buffer_a.GetBuffer(M * K * BatchSize);
@@ -167,6 +159,10 @@ class DynamicQgemmExecuteTest : public MlasTestFixture<MlasDynamicQgemmTest> {
167159
};
168160

169161
static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
162+
// Only register tests if MlasDynamicQGemmBatch() has an implementation available.
163+
if (!MlasIsDynamicQGemmAvailable()) {
164+
return size_t{0};
165+
}
166+
170167
return DynamicQgemmExecuteTest::RegisterAll(is_short_execute);
171168
});
172-
#endif

0 commit comments

Comments
 (0)