Merged

Commits (30)
7ca7687
Add logging macros for KleidiAI and fix todos from previous PR
orlmon01 Sep 24, 2025
6413140
Update onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matm…
orlmon01 Oct 2, 2025
6f14c5f
Merge branch 'main' into main
orlmon01 Oct 3, 2025
8bdfa00
Update dynamic_quantize_matmul.cc
orlmon01 Oct 8, 2025
d20eea2
Update CMakeLists.txt
orlmon01 Oct 9, 2025
a444642
Update dynamic_quantize_matmul.cc
orlmon01 Oct 9, 2025
218902e
Update dynamic_quantize_matmul.cc
orlmon01 Oct 12, 2025
dd50726
Update dynamic_quantize_matmul.cc
orlmon01 Oct 12, 2025
153546c
Change target compile defs for logging macros
orlmon01 Oct 12, 2025
c74d554
Update dynamic_quantize_matmul.cc
orlmon01 Oct 13, 2025
4ed4bd0
Merge branch 'main' into main
orlmon01 Oct 13, 2025
586b94f
Update dynamic_quantize_matmul.cc
orlmon01 Oct 14, 2025
6264d19
Merge branch 'microsoft:main' into main
orlmon01 Oct 16, 2025
817cb3a
Update build.py
orlmon01 Oct 16, 2025
12ba586
Update qgemm.cpp
orlmon01 Oct 20, 2025
bdc4734
Merge branch 'microsoft:main' into main
orlmon01 Oct 20, 2025
d334f63
Merge branch 'microsoft:main' into main
orlmon01 Oct 22, 2025
864a566
Merge branch 'microsoft:main' into main
orlmon01 Oct 22, 2025
3c27bba
Update onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
orlmon01 Oct 22, 2025
01c4202
Update cmake/CMakeLists.txt
orlmon01 Oct 22, 2025
73c95f4
Merge branch 'main' into main
orlmon01 Oct 28, 2025
f9b1e4b
* Move onnxruntime_KLEIDIAI_DEBUG_LOGGING from CMakeLists.txt to onnx…
orlmon01 Nov 3, 2025
32d1e0c
Merge branch 'microsoft:main' into main
orlmon01 Nov 3, 2025
9a50920
* Tidy up the checks around CblasNoTrans in sgemm_kleidiai.cpp
orlmon01 Nov 4, 2025
b5889cb
Merge branch 'microsoft:main' into main
orlmon01 Nov 4, 2025
cc454f6
KLEIDIAI_LOG should be defined for both kernel and debug logging
orlmon01 Nov 4, 2025
dc93b54
Merge branch 'microsoft:main' into main
orlmon01 Nov 4, 2025
b6fb2bf
Tidy up the formatting of the KLEDIAI logging messages
orlmon01 Nov 4, 2025
e1dc6e5
Merge branch 'microsoft:main' into main
orlmon01 Nov 4, 2025
01631a7
Merge branch 'microsoft:main' into main
orlmon01 Nov 5, 2025
27 changes: 8 additions & 19 deletions cmake/CMakeLists.txt
@@ -678,25 +678,14 @@ if(onnxruntime_USE_SVE)
endif()
endif()

if (onnxruntime_USE_KLEIDIAI AND (
(onnxruntime_target_platform STREQUAL "aarch64") OR
(onnxruntime_target_platform STREQUAL "ARM64") OR
(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")))

# TODO Add checks for MSVC Compilation
if(NOT MSVC)
check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD)
check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM)
if (NOT HAS_ARM64_DOTPROD)
message(FATAL_ERROR "The compiler doesn't support dotprod")
endif()
if (NOT HAS_ARM64_I8MM)
message(FATAL_ERROR "The compiler doesn't support i8mm")
endif()
else()
message(STATUS "Skipping -march= checks on MSVC (not supported), assuming dotprod/i8mm support manually.")
set(HAS_ARM64_DOTPROD TRUE)
set(HAS_ARM64_I8MM TRUE)
if(onnxruntime_USE_KLEIDIAI)
# If KLEIDIAI_DEBUG is enabled, that implies both DEBUG and KERNEL messages.
if(KLEIDIAI_DEBUG)
add_definitions(-DKLEIDIAI_DEBUG=1)
add_definitions(-DKLEIDIAI_KERNEL=1)
endif()
if(KLEIDIAI_KERNEL)
add_definitions(-DKLEIDIAI_KERNEL=1)
endif()
endif()

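With this change the logging levels are driven by two CMake cache variables rather than compiler-flag probes. As a sketch (the exact invocation depends on your setup; only the variable names come from this diff), a local configure with full logging might look like cmake -Donnxruntime_USE_KLEIDIAI=ON -DKLEIDIAI_DEBUG=1 <source-dir>, or via the build script something like ./build.sh --cmake_extra_defines KLEIDIAI_DEBUG=1. Setting only KLEIDIAI_KERNEL=1 enables the per-kernel messages without the general debug ones.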
@@ -222,9 +222,13 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
// Only handle the common case of a 2D weight matrix. Additional matrices
// could be handled by stacking the packed buffers.
b_shape_ = tensor.Shape();
// TO DO: handle b_shape_.NumDimensions() > 2 and all dimension values but the last two being 1.
if (!(b_shape_.NumDimensions() == 2 || (b_shape_.NumDimensions() == 3 && b_shape_[0] == 1))) {
can_use_dynamic_quant_mlas_ = false;
if (b_shape_.NumDimensions() > 2) {
for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) {
if (b_shape_[i] != 1) {
can_use_dynamic_quant_mlas_ = false;
break;
}
}
}

// Can we use the mlas dynamic Q gemm interface supported with float output ?
@@ -302,8 +306,10 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
int GetBIdx() const override { return IN_B; }

private:
// Indicates when dynamic quantization is available so we can use the KleidiAI dynamic qgemm kernel
bool can_use_dynamic_quant_mlas_{false};
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
// Indicates that the biases are a constant input and thus already quantized / packed
bool dynamic_quant_mlas_bias_data_was_packed_{false};
#endif
};
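For reviewers, a minimal standalone sketch of the shape gate added above (the helper name and example shapes are hypothetical, not part of this PR):

#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the check above: a B tensor of rank > 2 can use the dynamic quant
// path only when every dimension except the last two is 1, i.e. it is
// effectively a 2-D matrix.
static bool LeadingDimsAreOne(const std::vector<int64_t>& shape) {
    if (shape.size() <= 2) return true;   // plain {N, K} weights always pass
    for (size_t i = 0; i + 2 < shape.size(); ++i) {
        if (shape[i] != 1) return false;  // e.g. {2, N, K} is rejected
    }
    return true;                          // e.g. {1, 1, N, K} is accepted
}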
10 changes: 10 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -157,6 +157,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
ComputeConvOutSize(Parameters->InputShape[1],
ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]),
Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on functional checks.");
return false;
}

@@ -179,6 +180,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
MLAS_UNREFERENCED_PARAMETER(n_step);

if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks.");
return false;
}
return true;
@@ -326,6 +328,10 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
auto m_idx = static_cast<size_t>(tid) * m_step;
auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci);

KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme"
<< " M=" << (m < (m_idx + m_step) ? m - m_idx : m_step)
<< " k_chunk_count=" << (kh * kw)
<< " k_chunk_length=" << ci);
kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
m < (m_idx + m_step) ? m - m_idx : m_step, kh * kw, ci,
lhs_ptrs + m_idx * kh * kw,
@@ -375,6 +381,8 @@ static std::shared_ptr<std::byte[]> RhsPackWeightsBiasSme(const size_t co, const
bias_copy.resize(co, 0.0f);
}

KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme"
<< " N=" << co << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci << " rhs_stride_row=" << (co * sizeof(float)));
kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get()
);
@@ -599,6 +607,8 @@ static void ConvolveSme(const size_t co, //channels out
MIdx * m_step * co * sizeof(float) +
NIdx * n_step * sizeof(float)];

KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
<< " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
30 changes: 30 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -7,13 +7,43 @@
#pragma once

#include "mlasi.h"
#include <iostream>

// Fix to ensure compatibility with MSVC build
#if defined(_MSC_VER)
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

// Logging macros.
#ifndef KLEIDIAI_DEBUG
#define KLEIDIAI_DEBUG 0
#endif
#ifndef KLEIDIAI_KERNEL
#define KLEIDIAI_KERNEL 0
#endif

// General logging. "tag" is expected to qualify the type of message.
#define KLEIDIAI_LOG(tag, msg) std::cout << "[KLEIDIAI " << tag << "]: " << msg << std::endl

#if KLEIDIAI_DEBUG
// General debug messages.
#define KLEIDIAI_DEBUG_LOG(msg) KLEIDIAI_LOG("DEBUG", msg)
#else
#define KLEIDIAI_DEBUG_LOG(msg)
#endif

#if KLEIDIAI_KERNEL
// Messages specifically written before a call to kai_run.
// Note: In cases where a kernel is called from multiple threads, for example via MlasTrySimpleParallel,
// the output order can be inconsistent. The solution is to set the intra-op thread count to 1.
// If using onnxruntime_perf_test this is done with "-x 1".
#define KLEIDIAI_KERNEL_LOG(kernel_name) KLEIDIAI_LOG("KERNEL", kernel_name)
#else
#define KLEIDIAI_KERNEL_LOG(msg)
#endif

namespace ArmKleidiAI {
// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
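A self-contained sketch of how these macros expand; only the two #define lines mirror the header above, the values are invented:

#include <iostream>

#define KLEIDIAI_LOG(tag, msg) std::cout << "[KLEIDIAI " << tag << "]: " << msg << std::endl
#define KLEIDIAI_DEBUG_LOG(msg) KLEIDIAI_LOG("DEBUG", msg)

int main() {
    int N = 16, K = 32;
    KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize returning 0 size. N=" << N << " K=" << K);
    // Prints: [KLEIDIAI DEBUG]: MlasGemmPackBSize returning 0 size. N=16 K=32
    return 0;
}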
5 changes: 5 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
@@ -29,6 +29,8 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();

//regardless of kernel variant use neon packing variant
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon Groups=1"
<< " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr);
return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
}

@@ -102,9 +104,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
lhs = fallback.get();
}

KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32"
<< " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0");
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A,
Shape.K*sizeof(float), lhs);

KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa");
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB,
DataParams->C,
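The call order implied by this file, as a comment sketch (steps simplified; the kai_* names are the ones logged in the diff):

// 1. Query the packed-B size once via MlasDynamicQgemmPackBSize(N, K).
// 2. Pack B with the NEON packing variant (kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon),
//    regardless of which matmul kernel variant runs later.
// 3. Per batch entry: quant-pack the LHS with kai_run_lhs_quant_pack_qai8dxp_f32,
//    then run kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.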
15 changes: 15 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -45,6 +45,7 @@ Return Value:
--*/
{
if (TransA != CblasNoTrans || N == 0 || K == 0) {
KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize returning 0 size. N=" << N << " K=" << K);
return 0;
}
//
@@ -61,9 +62,11 @@ Return Value:
bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K);
break;
default:
KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize TransB is neither CblasNoTrans nor CblasTrans, returning 0.");
return 0;
}
} else {
KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize TransA is CblasTrans, returning 0.");
return 0;
}

@@ -130,17 +133,23 @@ Return Value:

switch (TransB) {
case CblasNoTrans:
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme Groups=1"
<< " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr << " rhs_stride_row=" << ldb * sizeof(float));
kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr);
break;
case CblasTrans:
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme Groups=1"
<< " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr << " rhs_stride_row=" << ldb * sizeof(float));
kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr);
break;
default:
KLEIDIAI_DEBUG_LOG("MlasGemmPackB TransB is neither CblasNoTrans nor CblasTrans, falling back to MLAS.");
return false;
}
return true;
}
else{
KLEIDIAI_DEBUG_LOG("MlasGemmPackB TransA is CblasTrans, falling back to MLAS.");
return false;
}
}
@@ -247,6 +256,8 @@ Return Value:
// We have already decided the matmul variant we are using, before having values for M,N,K
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_f32p2vlx1_f32_sme"
<< " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr);
kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
KaiPackedData[batch_idx].B = Data[batch_idx].B;
@@ -341,6 +352,8 @@ Return Value:
float* temp_tile = OutputTile.data();

if (UseSME2) {
KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"
<< " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
TileSizeM,
TileSizeN,
@@ -350,6 +363,8 @@ Return Value:
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
} else {
KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa"
<< " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
TileSizeM,
TileSizeN,
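Condensed, the tile loop these log lines sit in dispatches as follows (UseSME2 is the flag defined in mlasi_kleidiai.h; everything else here is simplified):

if (ArmKleidiAI::UseSME2) {
    // kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa per tile
} else {
    // kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa per tile
}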
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -880,7 +880,7 @@ typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)(
size_t N,
size_t K);

typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)(
typedef void (MLASCALL MLAS_GEMM_PACK_B)(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t N,
@@ -889,7 +889,7 @@ typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)(
size_t ldb,
void* PackedB);

typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)(
typedef bool (MLASCALL MLAS_GEMM_PACK_B_OVERRIDE)(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t N,
@@ -1315,7 +1315,7 @@ struct MLAS_PLATFORM {
// Mlas overrides initialisation
MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr;
MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr;
MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* MlasGemmPackBOverride = nullptr;
MLAS_GEMM_PACK_B_OVERRIDE* MlasGemmPackBOverride = nullptr;
MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr;
MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr;

4 changes: 3 additions & 1 deletion onnxruntime/core/mlas/lib/qgemm.cpp
@@ -406,7 +406,9 @@ Return Value:
const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) &
~(BufferAlignment - 1);
//If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
//this use case. Concat both packed representations for later decision.
//this use case. Concat both packed representations for later decision. This allows cases where
//can_use_dynamic_quant_mlas_ later turns out to be false to still use the prepacked B, at the cost
//of some extra memory; otherwise we can use the dynamic qgemm packing for better performance.
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
}

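A sketch of the concatenated buffer the updated comment describes (names from the surrounding function; layout illustrative only):

// PackedB:
//   [0, AlignedBytesRequired)                                standard qgemm packed B
//   [AlignedBytesRequired, +MlasDynamicQgemmPackBSize(N, K)) dynamic qgemm packed B
// so the dynamic representation would start at
//   static_cast<std::byte*>(PackedB) + AlignedBytesRequired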
1 change: 0 additions & 1 deletion tools/ci_build/build.py
@@ -887,7 +887,6 @@ def generate_build_tree(
# * Finally enable if platform.machine contains "arm64" and not a WebAssembly build. This should cover the following cases:
# * Linux on Arm
# * MacOs (case must be ignored)
# * TODO Delegate responsibility for Onnxruntime_USE_KLEIDIAI = ON to CMake logic
if not args.no_kleidiai:
if (
(args.android and "arm64" in args.android_abi.lower())
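The off switch is unchanged: passing --no_kleidiai to the build script (the args.no_kleidiai check above) skips enabling KleidiAI entirely, e.g. ./build.sh --config Release --no_kleidiai (invocation illustrative).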