Skip to content

Commit 17d822c

Browse files
committed
Implement FP32 kleidiai Gemv
Signed-off-by: Jonathan Clohessy <[email protected]> Signed-off-by: Jonathan Clohessy <[email protected]>
1 parent cf8476b commit 17d822c

File tree

7 files changed

+341
-70
lines changed

7 files changed

+341
-70
lines changed

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@
66

77
#include "kai_ukernel_interface.h"
88
#include "mlasi.h"
9+
#include "kleidiai/mlasi_kleidiai.h"
910

1011
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
1112
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
1213
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
1314
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
1415

16+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
17+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
18+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
19+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
20+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
21+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
22+
1523
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
1624
{kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
1725
kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -64,6 +72,56 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp
6472
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
6573
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};
6674

75+
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme =
76+
{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
77+
kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
78+
kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
79+
kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
80+
kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
81+
kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
82+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
83+
kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
84+
kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
85+
kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla};
86+
87+
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme2 =
88+
{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
89+
kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
90+
kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
91+
kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
92+
kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
93+
kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
94+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
95+
kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
96+
kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
97+
kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla};
98+
99+
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme =
100+
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
101+
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
102+
kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
103+
kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
104+
kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
105+
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
106+
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
107+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
108+
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
109+
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
110+
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa};
111+
112+
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 =
113+
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
114+
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
115+
kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
116+
kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
117+
kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
118+
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
119+
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
120+
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
121+
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
122+
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
123+
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa};
124+
67125
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
68126
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
69127
return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
@@ -79,3 +137,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
79137
return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
80138
}
81139
}
140+
141+
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() {
142+
if (ArmKleidiAI::SMEInfo::CanUseSME2) {
143+
return sgemm_gemm_sme2;
144+
} else {
145+
return sgemm_gemm_sme;
146+
}
147+
}
148+
149+
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
150+
if (ArmKleidiAI::SMEInfo::CanUseSME2) {
151+
return sgemm_gemv_sme2;
152+
} else {
153+
return sgemm_gemv_sme;
154+
}
155+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@
88

99
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
1010

11+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h"
12+
13+
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
14+
1115
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
1216
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
17+
18+
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel();
19+
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,21 @@
5050
#endif
5151

5252
namespace ArmKleidiAI {
53+
54+
struct SMEInfo {
55+
static const bool CanUseSME2;
56+
static const bool CanUseSME;
57+
static const bool IsSMEAvailable;
58+
};
59+
60+
// Boolean condition to determine if we can use SME2
5361
// By default we should try for SME2 first before falling back to SME.
54-
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
62+
inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
63+
// Boolean condition to determine if we can use SME
64+
inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();
65+
// Boolean condition to tell us if SME is enabled on this system
66+
inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME;
67+
5568

5669
//
5770
// Buffer packing routines.
@@ -78,6 +91,19 @@ MlasGemmPackB(
7891
void* PackedB
7992
);
8093

94+
bool
95+
MLASCALL
96+
MlasFp32Gemv(
97+
CBLAS_TRANSPOSE TransA,
98+
CBLAS_TRANSPOSE TransB,
99+
size_t M,
100+
size_t N,
101+
size_t K,
102+
const MLAS_SGEMM_DATA_PARAMS* Data,
103+
size_t BatchSize
104+
);
105+
106+
81107
bool
82108
MLASCALL
83109
MlasGemmBatch(

0 commit comments

Comments
 (0)