Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7ca7687
Add logging macros for KleidiAI and fix todos from previous PR
orlmon01 Sep 24, 2025
6413140
Update onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matm…
orlmon01 Oct 2, 2025
6f14c5f
Merge branch 'main' into main
orlmon01 Oct 3, 2025
8bdfa00
Update dynamic_quantize_matmul.cc
orlmon01 Oct 8, 2025
d20eea2
Update CMakeLists.txt
orlmon01 Oct 9, 2025
a444642
Update dynamic_quantize_matmul.cc
orlmon01 Oct 9, 2025
218902e
Update dynamic_quantize_matmul.cc
orlmon01 Oct 12, 2025
dd50726
Update dynamic_quantize_matmul.cc
orlmon01 Oct 12, 2025
153546c
Change target compile defs for logging macros
orlmon01 Oct 12, 2025
c74d554
Update dynamic_quantize_matmul.cc
orlmon01 Oct 13, 2025
4ed4bd0
Merge branch 'main' into main
orlmon01 Oct 13, 2025
586b94f
Update dynamic_quantize_matmul.cc
orlmon01 Oct 14, 2025
6264d19
Merge branch 'microsoft:main' into main
orlmon01 Oct 16, 2025
817cb3a
Update build.py
orlmon01 Oct 16, 2025
12ba586
Update qgemm.cpp
orlmon01 Oct 20, 2025
bdc4734
Merge branch 'microsoft:main' into main
orlmon01 Oct 20, 2025
d334f63
Merge branch 'microsoft:main' into main
orlmon01 Oct 22, 2025
864a566
Merge branch 'microsoft:main' into main
orlmon01 Oct 22, 2025
3c27bba
Update onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
orlmon01 Oct 22, 2025
01c4202
Update cmake/CMakeLists.txt
orlmon01 Oct 22, 2025
73c95f4
Merge branch 'main' into main
orlmon01 Oct 28, 2025
f9b1e4b
* Move onnxruntime_KLEIDIAI_DEBUG_LOGGING from CMakeLists.txt to onnx…
orlmon01 Nov 3, 2025
32d1e0c
Merge branch 'microsoft:main' into main
orlmon01 Nov 3, 2025
9a50920
* Tidy up the checks around CblasNoTrans in sgemm_kleidiai.cpp
orlmon01 Nov 4, 2025
b5889cb
Merge branch 'microsoft:main' into main
orlmon01 Nov 4, 2025
cc454f6
KLEIDIAI_LOG should be defined for both kernel and debug logging
orlmon01 Nov 4, 2025
dc93b54
Merge branch 'microsoft:main' into main
orlmon01 Nov 4, 2025
b6fb2bf
Tidy up the formatting of the KLEDIAI logging messages
orlmon01 Nov 4, 2025
e1dc6e5
Merge branch 'microsoft:main' into main
orlmon01 Nov 4, 2025
01631a7
Merge branch 'microsoft:main' into main
orlmon01 Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)
set(MLAS_SRC_DIR ${MLAS_ROOT}/lib)
set(MLAS_INC_DIR ${MLAS_ROOT}/inc)


# mlas_private_compile_definitions contains compile definitions that are private to onnxruntime_mlas and targets which
# use internal MLAS headers like mlasi.h.
set(mlas_private_compile_definitions)
Expand Down Expand Up @@ -285,6 +286,15 @@ function(setup_kleidiai)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai)
set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE)

# If KLEIDIAI_DEBUG is enabled that implies both DEBUG and KERNEL messages.
if(onnxruntime_KLEIDIAI_DEBUG_LOGGING)
target_compile_definitions(onnxruntime_mlas PRIVATE KLEIDIAI_DEBUG=1)
target_compile_definitions(onnxruntime_mlas PRIVATE KLEIDIAI_KERNEL=1)
endif()
if(onnxruntime_KLEIDIAI_KERNEL_LOGGING)
target_compile_definitions(onnxruntime_mlas PRIVATE KLEIDIAI_KERNEL=1)
endif()

if (NOT onnxruntime_BUILD_SHARED_LIB)
install(TARGETS kleidiai EXPORT ${PROJECT_NAME}Targets
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,14 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
// Only handle the common case of a 2D weight matrix. Additional matrices
// could be handled by stacking the packed buffers.
b_shape_ = tensor.Shape();
// TO DO: handle b_shape_.NumDimensions() > 2 and all dimension values but the last two being 1.
if (!(b_shape_.NumDimensions() == 2 || (b_shape_.NumDimensions() == 3 && b_shape_[0] == 1))) {
if (b_shape_.NumDimensions() >= 2) {
for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) {
if (b_shape_[i] != 1) {
can_use_dynamic_quant_mlas_ = false;
break;
}
}
} else {
can_use_dynamic_quant_mlas_ = false;
}

Expand Down Expand Up @@ -302,8 +308,10 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
int GetBIdx() const override { return IN_B; }

private:
// Indicates when MlasDynamicQGemmBatch() can be used
bool can_use_dynamic_quant_mlas_{false};
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
// Indicates that the biases are a constant input and thus already quantized / packed
bool dynamic_quant_mlas_bias_data_was_packed_{false};
#endif
};
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
ComputeConvOutSize(Parameters->InputShape[1],
ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]),
Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on functional checks.");
return false;
}

Expand All @@ -179,6 +180,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
MLAS_UNREFERENCED_PARAMETER(n_step);

if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks.");
return false;
}
return true;
Expand Down Expand Up @@ -326,6 +328,10 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
auto m_idx = static_cast<size_t>(tid) * m_step;
auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci);

KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme"
<< " M=" << (m < (m_idx + m_step) ? m - m_idx : m_step)
<< " k_chunk_count=" << (kh * kw)
<< " k_chunk_length=" << ci);
kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
m < (m_idx + m_step) ? m - m_idx : m_step, kh * kw, ci,
lhs_ptrs + m_idx * kh * kw,
Expand Down Expand Up @@ -375,6 +381,8 @@ static std::shared_ptr<std::byte[]> RhsPackWeightsBiasSme(const size_t co, const
bias_copy.resize(co, 0.0f);
}

KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme"
<< " N=" << co << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci << " rhs_stride_row=" << (co * sizeof(float)));
kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get()
);
Expand Down Expand Up @@ -599,6 +607,8 @@ static void ConvolveSme(const size_t co, //channels out
MIdx * m_step * co * sizeof(float) +
NIdx * n_step * sizeof(float)];

KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
<< " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
Expand Down
37 changes: 36 additions & 1 deletion onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,49 @@

#pragma once

#include "../mlasi.h"
#include "mlasi.h"
#include <iostream>

// Fix to ensure compatibility with MSVC build
#if defined(_MSC_VER)
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

// Logging macros.
#ifndef onnxruntime_KLEIDIAI_DEBUG_LOGGING
#define onnxruntime_KLEIDIAI_DEBUG_LOGGING 0
#endif
#ifndef onnxruntime_KLEIDIAI_KERNEL_LOGGING
#define onnxruntime_KLEIDIAI_KERNEL_LOGGING 0
#endif

#if onnxruntime_KLEIDIAI_DEBUG_LOGGING || onnxruntime_KLEIDIAI_KERNEL_LOGGING
#define KLEIDIAI_LOG(tag, msg) \
do { \
std::cout << __FILE__ << __LINE__ << "[KLEIDIAI " << tag << "]: " << msg << std::endl; \
} while(false)
#endif

// General logging. "tag" is expected to qualify the type of message.
#if onnxruntime_KLEIDIAI_DEBUG_LOGGING
// General debug messages.
#define KLEIDIAI_DEBUG_LOG(msg) KLEIDIAI_LOG("DEBUG", msg)
#else
#define KLEIDIAI_DEBUG_LOG(msg)
#endif

#if onnxruntime_KLEIDIAI_KERNEL_LOGGING
// Messages specifically written before a call to kai_run.
// Note: In cases where a kernel is called in multiple threads, for example MlasTrySimpleParallel,
// the output order can be inconsistient. The solution is to set the intra-node thread size to 1.
// If using onnxruntime_perf_test this is done with "--x 1".
#define KLEIDIAI_KERNEL_LOG(kernel_name) KLEIDIAI_LOG("KERNEL", kernel_name)
#else
#define KLEIDIAI_KERNEL_LOG(msg)
#endif

namespace ArmKleidiAI {
// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();

//regardless of kernel variant use neon packing variant
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon Groups=1"
<< " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr);
return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
}

Expand Down Expand Up @@ -102,9 +104,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
lhs = fallback.get();
}

KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32"
<< " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0");
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A,
Shape.K*sizeof(float), lhs);

KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa");
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB,
DataParams->C,
Expand Down
36 changes: 23 additions & 13 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,23 @@ Return Value:
--*/
{
if (TransA != CblasNoTrans || N == 0 || K == 0) {
KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize returning 0 size. N=" << N << " K=" << K);
return 0;
}
//
// Compute the number of bytes required to hold the packed buffer.
//
size_t bytes = 0;
if (TransA == CblasNoTrans) {
switch (TransB) {
case CblasNoTrans:
bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K);
break;
case CblasTrans:
bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K);
break;
default:
return 0;
}
} else {
return 0;
switch (TransB) {
case CblasNoTrans:
bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K);
break;
case CblasTrans:
bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K);
break;
default:
KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize TransB is neither CblasNoTrans nor CblasTrans, returning 0.");
return 0;
}

return bytes;
Expand Down Expand Up @@ -139,17 +137,23 @@ Return Value:

switch (TransB) {
case CblasNoTrans:
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme Groups=1"
<< " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr << " rhs_stride_row=" << ldb * sizeof(float));
kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
break;
case CblasTrans:
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme Groups=1"
<< " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr << " rhs_stride_row=" << ldb * sizeof(float));
kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
break;
default:
KLEIDIAI_DEBUG_LOG("MlasGemmPackB TransB is neither CblasNoTrans nor CblasTrans, falling back to MLAS.");
return false;
}
return true;
}
else{
KLEIDIAI_DEBUG_LOG("MlasGemmPackB TransA is CblasTrans, falling back to MLAS.");
return false;
}
}
Expand Down Expand Up @@ -263,6 +267,8 @@ Return Value:
// We have already decided the matmul variant we are using, before having values for M,N,K
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_f32p2vlx1_f32_sme"
<< " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr);
kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
});
} else {
Expand Down Expand Up @@ -365,6 +371,8 @@ Return Value:
std::fill_n(temp_tile, tile_elems, 0.0f);

if (UseSME2) {
KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"
<< " M=" << TileSizeM << " << N=" << TileSizeN << " K=" << K);
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
TileSizeM,
TileSizeN,
Expand All @@ -374,6 +382,8 @@ Return Value:
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
} else {
KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa"
<< " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
TileSizeM,
TileSizeN,
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)(
size_t N,
size_t K);

typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)(
typedef void (MLASCALL MLAS_GEMM_PACK_B)(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t N,
Expand All @@ -892,7 +892,7 @@ typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)(
size_t ldb,
void* PackedB);

typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)(
typedef bool (MLASCALL MLAS_GEMM_PACK_B_OVERRIDE)(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t N,
Expand Down Expand Up @@ -1329,7 +1329,7 @@ struct MLAS_PLATFORM {
// Mlas overrides initialisation
MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr;
MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr;
MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* MlasGemmPackBOverride = nullptr;
MLAS_GEMM_PACK_B_OVERRIDE* MlasGemmPackBOverride = nullptr;
MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr;
MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr;

Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/mlas/lib/qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,10 @@ Return Value:
const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) &
~(BufferAlignment - 1);
//If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
//this use case. Concat both packed representations for later decision.
// If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
// this use case. Concat both packed representations for later decision. This allows for cases later
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
// for better performance
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
}

Expand Down
Loading