Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET} safeint_interface)

target_compile_definitions(${mlas_target} PRIVATE ${mlas_private_compile_definitions})

Expand Down
94 changes: 47 additions & 47 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ QNBitGemmPackQuantBDataSize_Lasx(
BlkSumSize += SafeInt<size_t>(BlkSumAlignment) - 1;

PackedQuantBDataSize += ScaleSize + BlkSumSize;
return PackedQuantBDataSize.Value();
return static_cast<size_t>(PackedQuantBDataSize);
} else {
SafeInt<size_t> PackedQuantBDataSize = SafeInt<size_t>(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
return PackedQuantBDataSize.Value();
return static_cast<size_t>(PackedQuantBDataSize);
}
}

Expand All @@ -73,7 +73,7 @@ SQ4BitGemmPackQuantBData_Lasx(

const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const SafeInt<size_t> Iterations = SafeInt<size_t>(N) * BlockCountK; // one iteration per block
const size_t Iterations = SafeInt<size_t>(N) * BlockCountK; // one iteration per block

size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64);

Expand Down Expand Up @@ -105,14 +105,14 @@ SQ4BitGemmPackQuantBData_Lasx(
//

MlasTrySimpleParallel(
ThreadPool, Iterations.Value(),
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
const size_t n = tid / BlockCountK;
const size_t k_blk = tid % BlockCountK;

const SafeInt<size_t> data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
const std::byte* QuantBData = QuantBDataBegin + data_offset.Value();
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset.Value();
const size_t data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
const std::byte* QuantBData = QuantBDataBegin + data_offset;
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;

for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) {
for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) {
Expand Down Expand Up @@ -163,8 +163,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum_Lasx(
}

if (QuantBScaleBegin) {
SafeInt<size_t> offset = SafeInt<size_t>(N) * BlockCountK;
std::copy(QuantBScaleBegin, QuantBScaleBegin + offset.Value(), packed_quant_b.PackedQuantBScale);
size_t offset = SafeInt<size_t>(N) * BlockCountK;
std::copy(QuantBScaleBegin, QuantBScaleBegin + offset, packed_quant_b.PackedQuantBScale);
}

if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) {
Expand Down Expand Up @@ -272,14 +272,14 @@ ComputeDotProducts_BlkLen32Plus_CompFp32_lasx(

float scale_v[NCols];
UnrolledLoop<NCols>([&](size_t i) {
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
scale_v[i] = *(s + scale_offset.Value());
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
scale_v[i] = *(s + scale_offset);
});

std::byte* b_blk_data_col_ptr[NCols];
UnrolledLoop<NCols>([&](size_t i) {
SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * i;
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value());
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * i;
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset);
});

// not ready for "Manual conversion to float" in neon yet.
Expand Down Expand Up @@ -427,14 +427,14 @@ ComputeDotProducts_BlkLen16_CompFp32_lasx(

float scale_v[NCols];
UnrolledLoop<NCols>([&](size_t i) {
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
scale_v[i] = *(s + scale_offset.Value());
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
scale_v[i] = *(s + scale_offset);
});

std::byte* b_blk_data_col_ptr[NCols];
UnrolledLoop<NCols>([&](size_t i) {
SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * i;
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value());
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * i;
b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset);
});

if constexpr (HasZeroPoint) {
Expand Down Expand Up @@ -551,7 +551,7 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx(

float* SumPtr = CRowPtr;

int64_t nblk = <int64_t>(CountN) - NCols4;
// Remaining-column counter for the NCols4-wide main loop. It must go negative
// once fewer than NCols4 columns remain, so perform the subtraction in signed
// arithmetic; subtracting in size_t first would wrap around and only yield a
// negative value by way of the subsequent narrowing conversion.
int64_t nblk = static_cast<int64_t>(CountN) - static_cast<int64_t>(NCols4);
while (nblk >= 0) {
ComputeDotProducts_BlkLen16_CompFp32_lasx<NCols4, HasZeroPoint>(
BlkLen16,
Expand All @@ -560,13 +560,13 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx(
BiasPtr
);

SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
QuantBDataColPtr += data_offset.Value();
QuantBScaleColPtr += scale_offset.Value();
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
QuantBDataColPtr += data_offset;
QuantBScaleColPtr += scale_offset;
if constexpr (HasZeroPoint) {
SafeInt<size_t> zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
QuantBZeroPointColPtr += zeropoint_offset.Value();
size_t zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
QuantBZeroPointColPtr += zeropoint_offset;
}

BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
Expand Down Expand Up @@ -650,13 +650,13 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx(
);
}

SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
QuantBDataColPtr += data_offset.Value();
QuantBScaleColPtr += scale_offset.Value();
size_t data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
QuantBDataColPtr += data_offset;
QuantBScaleColPtr += scale_offset;
if constexpr (HasZeroPoint) {
SafeInt<size_t> zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
QuantBZeroPointColPtr += zeropoint_offset.Value();
size_t zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
QuantBZeroPointColPtr += zeropoint_offset;
}

BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
Expand Down Expand Up @@ -768,18 +768,18 @@ Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx(
for (size_t k = 0; k < BlockCountK; k++) {
// count # of tiles plus blks of the current tile from top
const size_t tile_count = col / GemmFloatKernelWidth16;
SafeInt<size_t> offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16;
float* dst_ptr = FpData + offset.Value();
size_t offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16;
float* dst_ptr = FpData + offset;
if (col % GemmFloatKernelWidth16 >= NCols8) {
// for the second half to 16 width tile
dst_ptr += NCols8;
}
SafeInt<size_t> b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
SafeInt<size_t> b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
SafeInt<size_t> b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
const std::byte* b_data_ptr = QuantBData + b_data_offset.Value();
const float* scale_ptr = QuantBScale + b_scale_offset.Value();
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value();
size_t b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
size_t b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
size_t b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
const std::byte* b_data_ptr = QuantBData + b_data_offset;
const float* scale_ptr = QuantBScale + b_scale_offset;
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset;
bool is_lower = (k % 2) == 0;

__m256i weight_16_epi16[NCols8];
Expand Down Expand Up @@ -910,19 +910,19 @@ Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx(
const size_t cols = std::min(NCols8, CountN - col);
for (size_t k = 0; k < BlockCountK; k++) {
// count # of tiles plus blks of the current tile from top
const size_t tile_count = col / GemmFloatKernelWidth16;
SafeInt<size_t> offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16;
float* dst_ptr = FpData + offset.Value();
const SafeInt<size_t> tile_count = col / GemmFloatKernelWidth16;
size_t offset = tile_count * CountK + k * BlkLen * GemmFloatKernelWidth16;
float* dst_ptr = FpData + offset;
if (col % GemmFloatKernelWidth16 >= NCols8) {
// for the second half to 16 width tile
dst_ptr += NCols8;
}
SafeInt<size_t> b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
SafeInt<size_t> b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
SafeInt<size_t> b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
const std::byte* b_data_ptr = QuantBData + b_data_offset.Value();
const float* scale_ptr = QuantBScale + b_scale_offset.Value();
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value();
size_t b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
size_t b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
size_t b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
const std::byte* b_data_ptr = QuantBData + b_data_offset;
const float* scale_ptr = QuantBScale + b_scale_offset;
const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset;
bool is_lower = (k % 2) == 0;

for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) {
Expand Down
Loading
Loading