diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index d091995ccf508..c0ab948b41fff 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -5,6 +5,7 @@
 set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)
 set(MLAS_SRC_DIR ${MLAS_ROOT}/lib)
 set(MLAS_INC_DIR ${MLAS_ROOT}/inc)
+
 # mlas_private_compile_definitions contains compile definitions that are private to onnxruntime_mlas and targets which
 # use internal MLAS headers like mlasi.h.
 set(mlas_private_compile_definitions)
@@ -285,6 +286,15 @@ function(setup_kleidiai)
   list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai)
   set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE)
 
+  # If KLEIDIAI_DEBUG is enabled, that implies both DEBUG and KERNEL messages.
+  if(onnxruntime_KLEIDIAI_DEBUG_LOGGING)
+    target_compile_definitions(onnxruntime_mlas PRIVATE KLEIDIAI_DEBUG=1)
+    target_compile_definitions(onnxruntime_mlas PRIVATE KLEIDIAI_KERNEL=1)
+  endif()
+  if(onnxruntime_KLEIDIAI_KERNEL_LOGGING)
+    target_compile_definitions(onnxruntime_mlas PRIVATE KLEIDIAI_KERNEL=1)
+  endif()
+
   if (NOT onnxruntime_BUILD_SHARED_LIB)
     install(TARGETS kleidiai EXPORT ${PROJECT_NAME}Targets
             ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
index 36a6f70cc69d9..98e69836078c9 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
@@ -222,8 +222,14 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
     // Only handle the common case of a 2D weight matrix. Additional matrices
     // could be handled by stacking the packed buffers.
     b_shape_ = tensor.Shape();
-    // TO DO: handle b_shape_.NumDimensions() > 2 and all dimension values but the last two being 1.
-    if (!(b_shape_.NumDimensions() == 2 || (b_shape_.NumDimensions() == 3 && b_shape_[0] == 1))) {
+    if (b_shape_.NumDimensions() >= 2) {
+      for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) {
+        if (b_shape_[i] != 1) {
+          can_use_dynamic_quant_mlas_ = false;
+          break;
+        }
+      }
+    } else {
       can_use_dynamic_quant_mlas_ = false;
     }
 
@@ -302,8 +308,10 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
   int GetBIdx() const override { return IN_B; }
 
  private:
+  // Indicates when MlasDynamicQGemmBatch() can be used
   bool can_use_dynamic_quant_mlas_{false};
 #if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
+  // Indicates that the biases are a constant input and thus already quantized / packed
   bool dynamic_quant_mlas_bias_data_was_packed_{false};
 #endif
 };
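
Note: the replacement shape check above retires the TODO it deletes. A weight tensor now qualifies for the dynamic-quant MLAS path whenever it has rank >= 2 and every dimension before the last two is 1, rather than only the rank-2 and rank-3-with-leading-1 cases. A standalone sketch of the same predicate, with a plain vector standing in for TensorShape (IsEffectively2D is an illustrative name, not ORT API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Mirrors the PrePack check above: the dynamic-quant MLAS path only handles
    // weights that are effectively one 2-D matrix, i.e. rank >= 2 with every
    // leading (batch) dimension equal to 1.
    static bool IsEffectively2D(const std::vector<std::int64_t>& shape) {
        if (shape.size() < 2) {
            return false;
        }
        for (std::size_t i = 0; i < shape.size() - 2; ++i) {
            if (shape[i] != 1) {
                return false;  // a real batch dimension: stacked matrices, not handled
            }
        }
        return true;
    }

    // IsEffectively2D({128, 64})       -> true
    // IsEffectively2D({1, 1, 128, 64}) -> true
    // IsEffectively2D({2, 128, 64})    -> false (leading dim != 1)
    // IsEffectively2D({64})            -> false (rank < 2)
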
diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
index 9eaf4902f536a..a5fb0667485d8 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -157,6 +157,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
          ComputeConvOutSize(Parameters->InputShape[1],
                             ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]),
                             Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) {
+        KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on functional checks.");
         return false;
     }
 
@@ -179,6 +180,7 @@
     MLAS_UNREFERENCED_PARAMETER(n_step);
 
     if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
+        KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks.");
         return false;
     }
     return true;
@@ -326,6 +328,10 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
         auto m_idx = static_cast<size_t>(tid) * m_step;
         auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci);
+        KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme"
+                            << " M=" << (m < (m_idx + m_step) ? m - m_idx : m_step)
+                            << " k_chunk_count=" << (kh * kw)
+                            << " k_chunk_length=" << ci);
         kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
             m < (m_idx + m_step) ? m - m_idx : m_step,
             kh * kw, ci,
             lhs_ptrs + m_idx * kh * kw,
@@ -375,6 +381,8 @@ static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const
         bias_copy.resize(co, 0.0f);
     }
 
+    KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme"
+                        << " N=" << co << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci << " rhs_stride_row=" << (co * sizeof(float)));
     kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
         co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get()
     );
@@ -599,6 +607,8 @@ static void ConvolveSme(const size_t co, //channels out
                                  MIdx * m_step * co * sizeof(float) +
                                  NIdx * n_step * sizeof(float)];
+            KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
+                                << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
             kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
                 TileSizeM, TileSizeN, d_kh*d_kw, ci,
                 ATile, BTile, CTile, co * sizeof(float),
                 -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
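
Note: in the MultiThreadedLHSPackSme hunk above, each thread packs an m_step-row slice starting at row m_idx, and the ternary m < (m_idx + m_step) ? m - m_idx : m_step trims the final slice when m is not a multiple of m_step. The same row-count computation as a minimal sketch:

    #include <algorithm>
    #include <cstddef>

    // Rows handled by the slice starting at m_idx: a full m_step except for the
    // final slice, which receives the remainder; equivalent to the ternary above
    // whenever m_idx < m.
    static std::size_t TileRows(std::size_t m, std::size_t m_idx, std::size_t m_step) {
        return std::min(m_step, m - m_idx);
    }

    // e.g. m = 100, m_step = 32 -> slices of 32, 32, 32 and finally 4 rows.
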
diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
index cd68a9d61a680..7e0aa4fe8f066 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
+++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -6,7 +6,8 @@
 
 #pragma once
 
-#include "../mlasi.h"
+#include "mlasi.h"
+#include <iostream>
 
 // Fix to ensure compatibility with MSVC build
 #if defined(_MSC_VER)
@@ -14,6 +15,40 @@
 #else
 #define RESTRICT __restrict__
 #endif
+
+// Logging macros. KLEIDIAI_DEBUG and KLEIDIAI_KERNEL are set as compile
+// definitions by the onnxruntime_KLEIDIAI_*_LOGGING options in onnxruntime_mlas.cmake.
+#ifndef KLEIDIAI_DEBUG
+#define KLEIDIAI_DEBUG 0
+#endif
+#ifndef KLEIDIAI_KERNEL
+#define KLEIDIAI_KERNEL 0
+#endif
+
+#if KLEIDIAI_DEBUG || KLEIDIAI_KERNEL
+#define KLEIDIAI_LOG(tag, msg) \
+  do { \
+    std::cout << "[KLEIDIAI " << tag << "]: " << __FILE__ << " : " << __LINE__ << " : " << msg << std::endl; \
+  } while(false)
+#endif
+
+// General logging. "tag" is expected to qualify the type of message.
+#if KLEIDIAI_DEBUG
+  // General debug messages.
+  #define KLEIDIAI_DEBUG_LOG(msg) KLEIDIAI_LOG("DEBUG", msg)
+#else
+  #define KLEIDIAI_DEBUG_LOG(msg)
+#endif
+
+#if KLEIDIAI_KERNEL
+  // Messages specifically written before a call to kai_run.
+  // Note: when a kernel is launched from multiple threads, for example via MlasTrySimpleParallel,
+  // the output order can be inconsistent. The workaround is to set the intra-op thread count to 1;
+  // with onnxruntime_perf_test this is done with "-x 1".
+  #define KLEIDIAI_KERNEL_LOG(msg) KLEIDIAI_LOG("KERNEL", msg)
+#else
+  #define KLEIDIAI_KERNEL_LOG(msg)
+#endif
+
 namespace ArmKleidiAI {
 // By default we should try for SME2 first before falling back to SME.
 inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
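
Note: because KLEIDIAI_LOG splices msg straight into an ostream chain, call sites throughout this diff can stream values with <<. A self-contained sketch of how the expansion behaves, using an always-on copy of the macro (the DEMO_* names are illustrative):

    #include <cstddef>
    #include <iostream>

    // Unconditional copy of KLEIDIAI_LOG from above, to show the expansion.
    #define DEMO_KLEIDIAI_LOG(tag, msg) \
      do { \
        std::cout << "[KLEIDIAI " << tag << "]: " << __FILE__ << " : " << __LINE__ << " : " << msg << std::endl; \
      } while (false)
    #define DEMO_KERNEL_LOG(msg) DEMO_KLEIDIAI_LOG("KERNEL", msg)

    int main() {
        std::size_t M = 4, N = 8, K = 16;
        // 'msg' expands inside the chain, so "<<" works within the argument:
        DEMO_KERNEL_LOG("kai_run_example" << " M=" << M << " N=" << N << " K=" << K);
        return 0;
    }
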
diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
index fb38f2cef9bf6..1d682b372e2f5 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
@@ -29,6 +29,8 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
     auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
 
     //regardless of kernel variant use neon packing variant
+    KLEIDIAI_KERNEL_LOG("kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon"
+                        << " N=" << N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr);
     return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
 }
 
@@ -102,9 +104,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
         lhs = fallback.get();
     }
 
+    KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32"
+                        << " M=" << Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0");
     kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A,
                                        Shape.K*sizeof(float), lhs);
 
+    KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa");
     kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
         Shape.M, Shape.N, Shape.K,
         lhs, DataParams->PackedB, DataParams->C,
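
Note: MlasDynamicQgemmPackBSize pairs with a pack routine in the usual MLAS size-query/allocate/pack pattern: callers first ask for the packed size, then allocate, then pack into the opaque buffer. A hedged sketch of that calling pattern with hypothetical stand-in functions (the real entry points take more parameters):

    #include <cstddef>
    #include <vector>

    // Hypothetical stand-ins for the size/pack pair; the real MLAS entry points
    // take additional parameters (strides, scales, and so on).
    std::size_t ExamplePackBSize(std::size_t N, std::size_t K);
    void ExamplePackB(std::size_t N, std::size_t K, const float* B, void* PackedB);

    // Query-allocate-pack: the size function runs first so the caller owns the
    // allocation, and the pack function then writes into the opaque buffer.
    std::vector<std::byte> PackWeights(std::size_t N, std::size_t K, const float* B) {
        std::vector<std::byte> packed(ExamplePackBSize(N, K));
        ExamplePackB(N, K, B, packed.data());
        return packed;
    }
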
diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
index 435ff1fb10017..848372d71e314 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -55,25 +55,23 @@ Return Value:
 --*/
 {
     if (TransA != CblasNoTrans || N == 0 || K == 0) {
+        KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize returning 0 size. N=" << N << " K=" << K);
         return 0;
     }
 
     //
     // Compute the number of bytes required to hold the packed buffer.
     //
     size_t bytes = 0;
-    if (TransA == CblasNoTrans) {
-        switch (TransB) {
-            case CblasNoTrans:
-                bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K);
-                break;
-            case CblasTrans:
-                bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K);
-                break;
-            default:
-                return 0;
-        }
-    } else {
-        return 0;
+    switch (TransB) {
+        case CblasNoTrans:
+            bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K);
+            break;
+        case CblasTrans:
+            bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K);
+            break;
+        default:
+            KLEIDIAI_DEBUG_LOG("MlasGemmPackBSize TransB is neither CblasNoTrans nor CblasTrans, returning 0.");
+            return 0;
     }
 
     return bytes;
@@ -139,17 +137,23 @@ Return Value:
         switch (TransB) {
             case CblasNoTrans:
+                KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme Groups=1"
+                                    << " N=" << N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr << " rhs_stride_row=" << ldb * sizeof(float));
                 kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B,
                                                                   g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
                 break;
             case CblasTrans:
+                KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme Groups=1"
+                                    << " N=" << N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr << " rhs_stride_row=" << ldb * sizeof(float));
                 kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B,
                                                                   g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
                 break;
             default:
+                KLEIDIAI_DEBUG_LOG("MlasGemmPackB TransB is neither CblasNoTrans nor CblasTrans, falling back to MLAS.");
                 return false;
         }
         return true;
     } else{
+        KLEIDIAI_DEBUG_LOG("MlasGemmPackB TransA is CblasTrans, falling back to MLAS.");
         return false;
     }
 }
@@ -263,6 +267,8 @@ Return Value:
     // We have already decided the matmul variant we are using, before having values for M,N,K
     MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
         std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
+        KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_f32p2vlx1_f32_sme"
+                            << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr);
         kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
     });
 } else {
@@ -365,6 +371,8 @@ Return Value:
                     std::fill_n(temp_tile, tile_elems, 0.0f);
 
                 if (UseSME2) {
+                    KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"
+                                        << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
                     kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
                         TileSizeM,
                         TileSizeN,
@@ -374,6 +382,8 @@ Return Value:
                         -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
                     );
                 } else {
+                    KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa"
+                                        << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
                     kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
                         TileSizeM,
                         TileSizeN,
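
Note: the TransB switch in MlasGemmPackB above selects the packing routine purely from how B is laid out in memory: the kxn variant when B is stored K x N (CblasNoTrans), the nxk variant when it is stored transposed, and a fallback to the generic MLAS path otherwise. The same selection reduced to a sketch:

    enum class BLayout { KxN, NxK, Other };  // stand-in for CBLAS_TRANSPOSE on B

    // Mirrors the switch in MlasGemmPackB above; nullptr means "fall back to
    // the default MLAS implementation".
    inline const char* RhsPackVariant(BLayout layout) {
        switch (layout) {
            case BLayout::KxN: return "kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme";
            case BLayout::NxK: return "kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme";
            default:           return nullptr;
        }
    }
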
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index d4b32a66bd914..ad62cccbfb9c7 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -883,7 +883,7 @@ typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)(
     size_t N,
     size_t K);
 
-typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)(
+typedef void (MLASCALL MLAS_GEMM_PACK_B)(
     CBLAS_TRANSPOSE TransA,
     CBLAS_TRANSPOSE TransB,
     size_t N,
@@ -892,7 +892,7 @@
     size_t ldb,
     void* PackedB);
 
-typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)(
+typedef bool (MLASCALL MLAS_GEMM_PACK_B_OVERRIDE)(
     CBLAS_TRANSPOSE TransA,
     CBLAS_TRANSPOSE TransB,
     size_t N,
@@ -1329,7 +1329,7 @@ struct MLAS_PLATFORM {
     // Mlas overrides initialisation
     MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr;
     MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr;
-    MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* MlasGemmPackBOverride = nullptr;
+    MLAS_GEMM_PACK_B_OVERRIDE* MlasGemmPackBOverride = nullptr;
     MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr;
     MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr;
diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp
index 4e9a0e27099dc..71e423ced3268 100644
--- a/onnxruntime/core/mlas/lib/qgemm.cpp
+++ b/onnxruntime/core/mlas/lib/qgemm.cpp
@@ -405,8 +405,10 @@ Return Value:
     const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
     const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) &
         ~(BufferAlignment - 1);
 
-    //If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
-    //this use case. Concat both packed representations for later decision.
+    // If this gemm B argument is used in a dynamic quantization gemm operation, we can optimize for
+    // this use case. Concatenate both packed representations so the decision can be made later: we
+    // keep the standard prepack available at the cost of some memory, and can otherwise use the
+    // dynamic qgemm packing for better performance.
     return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
 }
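
Note: the size returned above is the standard packed size rounded up to the preferred buffer alignment, with the dynamic-quant packed representation appended after it. The rounding is the usual power-of-two mask trick:

    #include <cassert>
    #include <cstddef>

    // Round bytes up to the next multiple of alignment (a power of two),
    // exactly as AlignedBytesRequired is computed above.
    inline std::size_t AlignUp(std::size_t bytes, std::size_t alignment) {
        assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
        return (bytes + alignment - 1) & ~(alignment - 1);
    }

    // e.g. AlignUp(100, 64) == 128; the dynamic-quant packed B then begins at
    // this aligned offset inside the single concatenated allocation.
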