Implement FP32 kleidiai Gemv #26302
Conversation
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
```cpp
kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
```
`GetKleidiAIXUKernel()` returns a `const&`. Do we need to make a copy here?
```diff
- kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
- kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
+ const kai_matmul_clamp_f32_f32p_f32p_ukernel& sgemm_gemm = GetKleidiAISGemmUKernel();
+ const kai_matmul_clamp_f32_f32_f32p_ukernel& sgemm_gemv = GetKleidiAISGemvUKernel();
```
Updated to `const` in the latest push.
`onnxruntime/core/mlas/lib/qgemm.cpp` (outdated)
```diff
  //No fallback and putting in guards
- if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
+ if(ArmKleidiAI::SMEInfo::CanUseSME2){
      ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
```
There are other places that need to be updated, like:

```cpp
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
```

```cpp
if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
```

I might be missing some.
I think it would be worth making a helper function like `MlasIsDynamicQGemmAvailable` that has the appropriate checks and using that instead.
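A minimal sketch of what such a helper might look like, assuming the SME2 flag discussed in this PR is the right availability predicate; the guard macro and the exact condition are assumptions:

```cpp
// Hypothetical helper per the suggestion above; it centralizes the
// availability check so call sites don't repeat CPU-feature queries.
inline bool MlasIsDynamicQGemmAvailable()
{
#if defined(MLAS_TARGET_ARM64)
    // Assumes the dynamic QGemm ukernels require SME2 (see the guard
    // change quoted above); otherwise the batch entry point is a no-op.
    return ArmKleidiAI::SMEInfo::CanUseSME2;
#else
    return false;
#endif
}
```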
Added the updated checks in various places like these in the latest push.
> I think it would be worth making a helper function like `MlasIsDynamicQGemmAvailable` that has the appropriate checks and using that instead.

To clarify, this was the main suggestion.
Force-pushed from a3f4f5b to e8ab1b1
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
```diff
  void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
    // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
-   if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+   if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
```
Nit: I guess the GTest skip comment needs a corresponding update too.
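For illustration, the updated skip could look something like this (a sketch; the message wording is an assumption):

```cpp
// Hypothetical updated guard for the MLAS unit test; assumes the test
// uses GTest's skip mechanism together with the new SME2 flag.
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
    GTEST_SKIP() << "MlasDynamicQGemmBatch() requires SME2; skipping.";
}
```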
`onnxruntime/core/mlas/lib/qgemm.cpp` (outdated)
```diff
  //No fallback and putting in guards
- if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
+ if(ArmKleidiAI::SMEInfo::CanUseSME2){
      ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
```
I guess after merging #26301, the checks looking for SME2 will go away, i.e. it can be run on both SME1 and SME2 then?
Yes, that's correct.
So one change I've made in the latest push is to remove this structure from our KleidiAI code specifically and put it into mlasi.h, dropping the ArmKleidiAI namespacing around it. That seemed like a sensible place for it, given that other similar CPU-feature code already lives there.
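As a rough sketch of the moved structure, assuming it keeps the same flag and that `MLAS_CPUIDINFO` exposes an SME2 query alongside `HasArm_SME()` (both assumptions):

```cpp
// Hypothetical shape of the structure after the move into mlasi.h,
// now outside the ArmKleidiAI namespace; the field and query names
// are assumptions based on the discussion above.
struct SMEInfo {
    // True when both SME and SME2 are reported, letting callers pick
    // SME2 ukernels without per-call feature checks or ternaries.
    static inline const bool CanUseSME2 =
        MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME() &&
        MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
};
```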
Force-pushed from 17d822c to 4afc95c
Signed-off-by: Jonathan Clohessy <[email protected]>
Force-pushed from 4afc95c to 1d9b7c8
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Description
Implementation of a special SGEMM path which uses GEMV kernels in cases where M or N is 1.

Additionally, this PR introduces the use of a microkernel interface built on typedefs provided by KleidiAI, which simplifies the code and removes things such as the ternary selection between SME1 and SME2 kernels.
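A minimal sketch of the dispatch idea, with hypothetical stand-ins for the actual ukernel invocations (the real PR selects and packs the concrete `kai_matmul_clamp_f32_*` ukernels):

```cpp
#include <cstddef>

// Minimal sketch of the M/N == 1 dispatch described above; RunGemv and
// RunGemm are hypothetical stand-ins for invoking the KleidiAI GEMV and
// GEMM ukernels respectively.
template <typename RunGemv, typename RunGemm>
void SgemmDispatch(size_t M, size_t N, RunGemv run_gemv, RunGemm run_gemm) {
    // A matrix-vector product arises when either output dimension is 1,
    // so route those cases to the GEMV ukernel and everything else to
    // the packed GEMM path.
    if (M == 1 || N == 1) {
        run_gemv();
    } else {
        run_gemm();
    }
}
```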
Indicative Performance
In lieu of a production model where GEMV is a large contributor to the network, I opted to create a mini model for testing, containing thousands of randomized MatMul variants with a distribution of GEMV cases throughout.

Using onnxruntime_perf_test, I was able to halve the total inference time vs. MLAS with this model.

More benchmarks to come shortly.