Skip to content

Commit a16573f

Browse files
Fix safeint-related errors in MLAS
1. Add the safeint header include path to MLAS. 2. Fix syntax errors in sqnbitgemm and lasx. Co-authored-by: lixing [email protected]
1 parent 34b7558 commit a16573f

File tree

3 files changed

+82
-82
lines changed

3 files changed

+82
-82
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -828,7 +828,7 @@ endif()
828828
endif()
829829

830830
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
831-
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
831+
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR} ${safeint_SOURCE_DIR})
832832
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
833833

834834
target_compile_definitions(${mlas_target} PRIVATE ${mlas_private_compile_definitions})

onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp

Lines changed: 46 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,10 @@ QNBitGemmPackQuantBDataSize_Lasx(
4848
BlkSumSize += SafeInt<size_t>(BlkSumAlignment) - 1;
4949

5050
PackedQuantBDataSize += ScaleSize + BlkSumSize;
51-
return PackedQuantBDataSize.Value();
51+
return static_cast<size_t>(PackedQuantBDataSize);
5252
} else {
5353
SafeInt<size_t> PackedQuantBDataSize = SafeInt<size_t>(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
54-
return PackedQuantBDataSize.Value();
54+
return static_cast<size_t>(PackedQuantBDataSize);
5555
}
5656
}
5757

@@ -73,7 +73,7 @@ SQ4BitGemmPackQuantBData_Lasx(
7373

7474
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
7575
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
76-
const SafeInt<size_t> Iterations = SafeInt<size_t>(N) * BlockCountK; // one iteration per block
76+
const size_t Iterations = SafeInt<size_t>(N) * BlockCountK; // one iteration per block
7777

7878
size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64);
7979

@@ -105,14 +105,14 @@ SQ4BitGemmPackQuantBData_Lasx(
105105
//
106106

107107
MlasTrySimpleParallel(
108-
ThreadPool, Iterations.Value(),
108+
ThreadPool, Iterations,
109109
[&](ptrdiff_t tid) {
110110
const size_t n = tid / BlockCountK;
111111
const size_t k_blk = tid % BlockCountK;
112112

113-
const SafeInt<size_t> data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
114-
const std::byte* QuantBData = QuantBDataBegin + data_offset.Value();
115-
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset.Value();
113+
const size_t data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
114+
const std::byte* QuantBData = QuantBDataBegin + data_offset;
115+
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
116116

117117
for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) {
118118
for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) {
@@ -163,8 +163,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum_Lasx(
163163
}
164164

165165
if (QuantBScaleBegin) {
166-
SafeInt<size_t> offset = SafeInt<size_t>(N) * BlockCountK;
167-
std::copy(QuantBScaleBegin, QuantBScaleBegin + offset.Value(), packed_quant_b.PackedQuantBScale);
166+
size_t offset = SafeInt<size_t>(N) * BlockCountK;
167+
std::copy(QuantBScaleBegin, QuantBScaleBegin + offset, packed_quant_b.PackedQuantBScale);
168168
}
169169

170170
if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) {
@@ -272,14 +272,14 @@ ComputeDotProducts_BlkLen32Plus_CompFp32_lasx(
272272

273273
float scale_v[NCols];
274274
UnrolledLoop<NCols>([&](size_t i) {
275-
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
276-
scale_v[i] = *(s + scale_offset.Value());
275+
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
276+
scale_v[i] = *(s + scale_offset);
277277
});
278278

279279
std::byte* b_blk_data_col_ptr[NCols];
280280
UnrolledLoop<NCols>([&](size_t i) {
281-
SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * i;
282-
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value());
281+
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * i;
282+
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset);
283283
});
284284

285285
// not ready for "Manual conversion to float" in neon yet.
@@ -427,14 +427,14 @@ ComputeDotProducts_BlkLen16_CompFp32_lasx(
427427

428428
float scale_v[NCols];
429429
UnrolledLoop<NCols>([&](size_t i) {
430-
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
431-
scale_v[i] = *(s + scale_offset.Value());
430+
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
431+
scale_v[i] = *(s + scale_offset);
432432
});
433433

434434
std::byte* b_blk_data_col_ptr[NCols];
435435
UnrolledLoop<NCols>([&](size_t i) {
436-
SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * i;
437-
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value());
436+
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * i;
437+
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset);
438438
});
439439

440440
if constexpr (HasZeroPoint) {
@@ -551,7 +551,7 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx(
551551

552552
float* SumPtr = CRowPtr;
553553

554-
int64_t nblk = <int64_t>(CountN) - NCols4;
554+
int64_t nblk = static_cast<int64_t>(CountN - NCols4);
555555
while (nblk >= 0) {
556556
ComputeDotProducts_BlkLen16_CompFp32_lasx<NCols4, HasZeroPoint>(
557557
BlkLen16,
@@ -560,13 +560,13 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx(
560560
BiasPtr
561561
);
562562

563-
SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
564-
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
565-
QuantBDataColPtr += data_offset.Value();
566-
QuantBScaleColPtr += scale_offset.Value();
563+
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
564+
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
565+
QuantBDataColPtr += data_offset;
566+
QuantBScaleColPtr += scale_offset;
567567
if constexpr (HasZeroPoint) {
568-
SafeInt<size_t> zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
569-
QuantBZeroPointColPtr += zeropoint_offset.Value();
568+
size_t zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
569+
QuantBZeroPointColPtr += zeropoint_offset;
570570
}
571571

572572
BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
@@ -650,13 +650,13 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx(
650650
);
651651
}
652652

653-
SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
654-
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
655-
QuantBDataColPtr += data_offset.Value();
656-
QuantBScaleColPtr += scale_offset.Value();
653+
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
654+
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
655+
QuantBDataColPtr += data_offset;
656+
QuantBScaleColPtr += scale_offset;
657657
if constexpr (HasZeroPoint) {
658-
SafeInt<size_t> zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
659-
QuantBZeroPointColPtr += zeropoint_offset.Value();
658+
size_t zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
659+
QuantBZeroPointColPtr += zeropoint_offset;
660660
}
661661

662662
BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
@@ -768,18 +768,18 @@ Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx(
768768
for (size_t k = 0; k < BlockCountK; k++) {
769769
// count # of tiles plus blks of the current tile from top
770770
const size_t tile_count = col / GemmFloatKernelWidth16;
771-
SafeInt<size_t> offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16;
772-
float* dst_ptr = FpData + offset.Value();
771+
size_t offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16;
772+
float* dst_ptr = FpData + offset;
773773
if (col % GemmFloatKernelWidth16 >= NCols8) {
774774
// for the second half to 16 width tile
775775
dst_ptr += NCols8;
776776
}
777-
SafeInt<size_t> b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
778-
SafeInt<size_t> b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
779-
SafeInt<size_t> b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
780-
const std::byte* b_data_ptr = QuantBData + b_data_offset.Value();
781-
const float* scale_ptr = QuantBScale + b_scale_offset.Value();
782-
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value();
777+
size_t b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
778+
size_t b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
779+
size_t b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
780+
const std::byte* b_data_ptr = QuantBData + b_data_offset;
781+
const float* scale_ptr = QuantBScale + b_scale_offset;
782+
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset;
783783
bool is_lower = (k % 2) == 0;
784784

785785
__m256i weight_16_epi16[NCols8];
@@ -911,18 +911,18 @@ Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx(
911911
for (size_t k = 0; k < BlockCountK; k++) {
912912
// count # of tiles plus blks of the current tile from top
913913
const size_t tile_count = col / GemmFloatKernelWidth16;
914-
SafeInt<size_t> offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16;
915-
float* dst_ptr = FpData + offset.Value();
914+
size_t offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16;
915+
float* dst_ptr = FpData + offset;
916916
if (col % GemmFloatKernelWidth16 >= NCols8) {
917917
// for the second half to 16 width tile
918918
dst_ptr += NCols8;
919919
}
920-
SafeInt<size_t> b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
921-
SafeInt<size_t> b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
922-
SafeInt<size_t> b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
923-
const std::byte* b_data_ptr = QuantBData + b_data_offset.Value();
924-
const float* scale_ptr = QuantBScale + b_scale_offset.Value();
925-
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value();
920+
size_t b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
921+
size_t b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
922+
size_t b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
923+
const std::byte* b_data_ptr = QuantBData + b_data_offset;
924+
const float* scale_ptr = QuantBScale + b_scale_offset;
925+
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset;
926926
bool is_lower = (k % 2) == 0;
927927

928928
for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) {

0 commit comments

Comments
 (0)