-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Implement multithreading in qgemm_kleidi #26301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
d39b0e9
28be32c
332469e
9ea1898
8666505
7a96b3f
b38958b
0623beb
e54f509
9a4339e
4015816
b7b670f
143c882
1c707b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,10 +11,19 @@ | |||||||||||||||||||||||
| #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" | ||||||||||||||||||||||||
| #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h" | ||||||||||||||||||||||||
| #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include "mlasi_kleidiai.h" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Thread-local reusable buffers to reduce allocation overhead across tiles. | ||||||||||||||||||||||||
| struct KaiTlsBuffersQgemm { | ||||||||||||||||||||||||
| std::vector<float> output_tile; | ||||||||||||||||||||||||
| std::vector<float> bias_zero; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| std::vector<std::byte> lhs_packed; | ||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||
| static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| //Matmul with float output of dynamic quantized A and symmetric quantized B. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| size_t | ||||||||||||||||||||||||
|
|
@@ -78,39 +87,147 @@ MLASCALL | |||||||||||||||||||||||
| ArmKleidiAI::MlasDynamicQGemmBatch( | ||||||||||||||||||||||||
| const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, | ||||||||||||||||||||||||
| const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, | ||||||||||||||||||||||||
| const size_t BatchN, | ||||||||||||||||||||||||
| const size_t BatchSize, | ||||||||||||||||||||||||
| MLAS_THREADPOOL* ThreadPool | ||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||
| for (auto b = BatchN; b > 0; --b,++DataParams) { | ||||||||||||||||||||||||
| auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); | ||||||||||||||||||||||||
| auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); | ||||||||||||||||||||||||
| auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() | ||||||||||||||||||||||||
| : kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); | ||||||||||||||||||||||||
| const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() | ||||||||||||||||||||||||
| : kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); | ||||||||||||||||||||||||
| const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() | ||||||||||||||||||||||||
| : kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| //TODO enable multi-threading for lhs packing and matmul | ||||||||||||||||||||||||
| MLAS_UNREFERENCED_PARAMETER(ThreadPool); | ||||||||||||||||||||||||
| size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() | ||||||||||||||||||||||||
| : kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); | ||||||||||||||||||||||||
| size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() | ||||||||||||||||||||||||
| : kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| //Dynamic Quantize A - lhs | ||||||||||||||||||||||||
| auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); | ||||||||||||||||||||||||
| std::byte* lhs = nullptr; | ||||||||||||||||||||||||
| std::unique_ptr<std::byte[]> fallback; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { | ||||||||||||||||||||||||
| lhs = static_cast<std::byte*>(DataParams->Workspace); | ||||||||||||||||||||||||
| if (Shape.M == 0 || Shape.N == 0) { | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| if ((Shape.M < m_step || Shape.N < n_step) && !DataParams->PackedB) { | ||||||||||||||||||||||||
| // Fallback to MLAS | ||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is no fallback implementation of onnxruntime/onnxruntime/core/mlas/lib/qgemm.cpp Lines 212 to 222 in 0f6cffc
if we get to this point, the computation should happen or (maybe less preferably) it should be a hard error.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will investigate the fallback case further and try to provide better implementation. ORT_ENFORCE(false, "ArmKleidiAI::MlasDynamicQGemmBatch(): unsupported small-shape case (M < m_step or N < n_step)");
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we instead implement @edgchen1's suggestion in the other PR: #26302 (comment) to have a universal check that can be used in all places to check if MLAS supports QGemm for that problem shape, platform, etc. ? Also since we have a check on the Just curious - what would happen if the M was < m_step ? Would there be a crash or would the perf be sub-optimal ? If so, we need to add a runtime check in the CPU kernel's Run() function which means we may need to perform pre-packing for both KAI and the "regular" path. See here. |
||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| //Dynamic Quantize A - lhs | ||||||||||||||||||||||||
| const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); | ||||||||||||||||||||||||
| std::byte* LhsPackedData = nullptr; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) { | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize); | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we just do the resizing directly instead of reserve + resize ?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, reserve() + resize() or using only resize() cases both end up with one allocation + one initialisation. But somehow there is a very very little performance difference in the case allocation and initialisation separated or done at once with resize(). (after: is the case reserve() calls removed and only resize() is used.) |
||||||||||||||||||||||||
| LhsPackedData = g_kai_tls_qgemm.lhs_packed.data(); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| //Per-batch table of lhs | ||||||||||||||||||||||||
| std::vector<const std::byte*> LhsBase(BatchSize); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| // B batches require no packing | ||||||||||||||||||||||||
| // We have already decided the matmul variant we are using, before having values for M,N,K | ||||||||||||||||||||||||
| MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| std::byte* lhs = nullptr; | ||||||||||||||||||||||||
| if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) { | ||||||||||||||||||||||||
| lhs = static_cast<std::byte*>(DataParams[batch_idx].Workspace); | ||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||
| fallback = std::make_unique<std::byte[]>(lhs_size); | ||||||||||||||||||||||||
| lhs = fallback.get(); | ||||||||||||||||||||||||
| lhs = &(LhsPackedData[LhsPackedStride * batch_idx]); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, | ||||||||||||||||||||||||
| Shape.K*sizeof(float), lhs); | ||||||||||||||||||||||||
| kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs); | ||||||||||||||||||||||||
| LhsBase[batch_idx] = lhs; | ||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // tile iteration dimensions | ||||||||||||||||||||||||
| std::array<size_t, 3> dim; | ||||||||||||||||||||||||
| dim[0] = BatchSize; // B | ||||||||||||||||||||||||
| dim[1] = MlasDivRoundup(Shape.M, m_step); // M | ||||||||||||||||||||||||
| dim[2] = MlasDivRoundup(Shape.N, n_step); // N | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Minimize the kernel call count for the number of available threads | ||||||||||||||||||||||||
| auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // scale required tiles over available tile processors | ||||||||||||||||||||||||
| dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); | ||||||||||||||||||||||||
| dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( | ||||||||||||||||||||||||
| Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, | ||||||||||||||||||||||||
| DataParams->C, | ||||||||||||||||||||||||
| Shape.N * sizeof(float), | ||||||||||||||||||||||||
| sizeof(float), | ||||||||||||||||||||||||
| -std::numeric_limits<float>::max(), std::numeric_limits<float>::max() | ||||||||||||||||||||||||
| // compute new step sizes | ||||||||||||||||||||||||
| m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step); | ||||||||||||||||||||||||
| n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // update tile iterations | ||||||||||||||||||||||||
| dim[1] = MlasDivRoundup(Shape.M, m_step); | ||||||||||||||||||||||||
| dim[2] = MlasDivRoundup(Shape.N, n_step); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // compute B,M,N index from iteration index | ||||||||||||||||||||||||
| ptrdiff_t BIdx = tid / (dim[1] * dim[2]); | ||||||||||||||||||||||||
| ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; | ||||||||||||||||||||||||
| ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Get rhs tile, B | ||||||||||||||||||||||||
| const size_t rhs_packed_offset = | ||||||||||||||||||||||||
| UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K) | ||||||||||||||||||||||||
| : kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const std::byte* B_base = reinterpret_cast<const std::byte*>(DataParams[BIdx].PackedB); | ||||||||||||||||||||||||
| auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Get lhs tile, A | ||||||||||||||||||||||||
| const size_t lhs_packed_offset = | ||||||||||||||||||||||||
| UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K) | ||||||||||||||||||||||||
| : kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const std::byte* A_base = LhsBase[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace; | ||||||||||||||||||||||||
| auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? (Shape.M - MIdx * m_step) : m_step; | ||||||||||||||||||||||||
| auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? (Shape.N - NIdx * n_step) : n_step; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Get result tile, C | ||||||||||||||||||||||||
| auto CTile = reinterpret_cast<void*>( | ||||||||||||||||||||||||
| reinterpret_cast<std::byte*>(DataParams[BIdx].C) + | ||||||||||||||||||||||||
| MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) + | ||||||||||||||||||||||||
| NIdx * n_step * sizeof(float) | ||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| // Allocate temporary buffer for raw A*B result (TLS reusable buffer) | ||||||||||||||||||||||||
| { | ||||||||||||||||||||||||
| const size_t tile_elems = TileSizeM * TileSizeN; | ||||||||||||||||||||||||
| if (g_kai_tls_qgemm.output_tile.capacity() < tile_elems) { | ||||||||||||||||||||||||
| // reserve more memory if required | ||||||||||||||||||||||||
| g_kai_tls_qgemm.output_tile.reserve(tile_elems); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| // resize the tile to the required size (doesn't effect memory) | ||||||||||||||||||||||||
| g_kai_tls_qgemm.output_tile.resize(tile_elems); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| float* temp_tile = g_kai_tls_qgemm.output_tile.data(); | ||||||||||||||||||||||||
| std::fill_n(temp_tile, TileSizeM * TileSizeN, 0.0f); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if (UseSME2) { | ||||||||||||||||||||||||
| kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( | ||||||||||||||||||||||||
| TileSizeM, TileSizeN, Shape.K, ATile, BTile, | ||||||||||||||||||||||||
| temp_tile, | ||||||||||||||||||||||||
| TileSizeN * sizeof(float), | ||||||||||||||||||||||||
| sizeof(float), | ||||||||||||||||||||||||
| -std::numeric_limits<float>::max(), std::numeric_limits<float>::max() | ||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| else { | ||||||||||||||||||||||||
| kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa( | ||||||||||||||||||||||||
| TileSizeM, TileSizeN, Shape.K, ATile, BTile, | ||||||||||||||||||||||||
| temp_tile, | ||||||||||||||||||||||||
| TileSizeN * sizeof(float), | ||||||||||||||||||||||||
| sizeof(float), | ||||||||||||||||||||||||
| -std::numeric_limits<float>::max(), std::numeric_limits<float>::max() | ||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Final output tile pointer | ||||||||||||||||||||||||
| float* dst_tile = reinterpret_cast<float*>(CTile); | ||||||||||||||||||||||||
| std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||





There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this update in the KAI version from 1.10 to 1.15, can SME/SME2 detection be enabled on Windows too to leverage the kernels ?
https://github.com/microsoft/onnxruntime/pull/25187/files#r2223006773
https://github.com/microsoft/onnxruntime/pull/25760/files#r2325260570