Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#param has_bias
#param has_zero_points
#param is_qualcomm
#param has_weight_idx

#use .getByOffset .setByOffset

Expand Down Expand Up @@ -76,9 +77,12 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
return;
}

#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
let b_value = b.getByOffset(b_weight_offset + b_global*uniforms.K16+kidx_v+col);
let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
#else
let b_value = b.getByOffset(b_global * uniforms.K16+kidx_v + col);
#endif
let block_idx = kidx_v/(block_size/16);
let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero);
Expand All @@ -100,8 +104,13 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
return;
}

#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
let b_value = b.getByOffset(b_weight_offset + b_global*uniforms.K16+kidx_v+col);
let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
#else
const b_weight_offset : u32 = 0;
let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
#endif
tile_B[col][row] = AlignWithZeroPoint(b_value);
if (col == 0)
{
Expand All @@ -124,8 +133,14 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
return;
}
#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
#else
const b_weight_offset : u32 = 0;
let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
#endif
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
let b_value = b.getByOffset(b_weight_offset + b_global*uniforms.K16+kidx_v+col);
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
let block_idx = kidx_v/(block_size/16);
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
Expand All @@ -137,9 +152,6 @@ $MAIN {
#if n_bits == 2
LoadDequantizationTable(local_idx);
workgroupBarrier();
#endif
#if has_bias
let b_bias_offset = uniforms.weight_idx * uniforms.N;
#endif
// During the load phase we use all 256 threads to load 64 rows of A/B.
// For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
Expand Down Expand Up @@ -370,6 +382,13 @@ $MAIN {
let a_global = a_global_base + base_A + a_idx;
let b_global = b_global_base + base_B;
let output_idx = ((a_global) * uniforms.N + b_global)/4;
#if has_bias
#if has_weight_idx
let b_bias_offset = uniforms.weight_idx * uniforms.N;
#else
const b_bias_offset : u32 = 0;
#endif
#endif
// This creates a shader requirement that uniforms.N % 16 == 0
if (a_global < uniforms.M && b_global < uniforms.N)
{
Expand Down Expand Up @@ -411,6 +430,7 @@ $MAIN {
#endif
#else
#if has_bias
// TODO: wanted to use vec4 for bias but for some reason that fails ut. Later.
let bias_vec1 = vec4<output_element_t>(
bias[b_global + 0 + b_bias_offset],
bias[b_global + 1 + b_bias_offset],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

#param n_bits
#param has_bias
#param has_zero_points

#include "quantization/matmul_nbits_zero_pt.wgsl.template"
Expand Down
12 changes: 8 additions & 4 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template",
WGSL_TEMPLATE_PARAMETER(block_size, block_size_),
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
Expand Down Expand Up @@ -65,6 +66,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co

return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
WGSL_TEMPLATE_PARAMETER(output_type_i32, true),
Expand Down Expand Up @@ -109,8 +111,10 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{&a_scale, ProgramTensorMetadataDependency::Rank, 1}})
.AddUniformVariable({M * K / kU32Components});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

const bool has_zero_points = zero_points != nullptr;
const bool has_bias = bias != nullptr;
const bool has_weight_idx = weight_index != 0;
const bool single_scale_weights = (block_size == K * N);
if (M < min_M_for_tile_optimization) {
uint32_t tile_size_k_vec = 16;
Expand All @@ -121,7 +125,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
tile_size_n = 4;
}
const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components);
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, single_scale_weights};
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights};
uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
mul_program.SetWorkgroupSize(128);
mul_program.SetDispatchGroupSize(M * num_N_tile);
Expand All @@ -131,7 +135,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias);
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx);
if (has_zero_points) {
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
Expand All @@ -146,7 +150,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, is_qualcomm};
DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm};
mul_program.SetWorkgroupSize(256);
mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
Expand All @@ -162,7 +166,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{zero_blocks_per_col},
{weight_index}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)})
.CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias);
.CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx);
if (has_zero_points) {
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
Expand Down
32 changes: 19 additions & 13 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram
class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
public:
DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits,
bool has_zero_points, bool has_bias, bool is_qualcomm) : Program{"DP4AMatMulNBits"},
block_size_(block_size),
nbits_(nbits),
has_bias_(has_bias),
has_zero_points_(has_zero_points),
is_qualcomm_(is_qualcomm) {}
bool has_zero_points, bool has_bias,
bool has_weight_idx, bool is_qualcomm) : Program{"DP4AMatMulNBits"},
block_size_(block_size),
nbits_(nbits),
has_bias_(has_bias),
has_zero_points_(has_zero_points),
has_weight_idx_(has_weight_idx),
is_qualcomm_(is_qualcomm) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
Expand All @@ -44,19 +46,22 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
uint32_t nbits_;
bool has_bias_;
bool has_zero_points_;
bool has_weight_idx_;
bool is_qualcomm_;
};

class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
public:
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits,
bool has_zero_points, bool has_bias, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"},
tile_size_k_vec_(tile_size_k_vec),
tile_size_(tile_size),
nbits_(nbits),
has_bias_(has_bias),
has_zero_points_(has_zero_points),
single_scale_weights_(single_scale_weights) {}
bool has_zero_points, bool has_bias,
bool has_weight_idx, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"},
tile_size_k_vec_(tile_size_k_vec),
tile_size_(tile_size),
nbits_(nbits),
has_bias_(has_bias),
has_zero_points_(has_zero_points),
has_weight_idx_(has_weight_idx),
single_scale_weights_(single_scale_weights) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
Expand All @@ -75,6 +80,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP
uint32_t nbits_;
bool has_bias_;
bool has_zero_points_;
bool has_weight_idx_;
bool single_scale_weights_;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#param n_bits
#param has_zero_points
#param has_bias
#param has_weight_idx

#use .getByOffset .setByOffset

Expand Down Expand Up @@ -67,7 +68,11 @@ $MAIN {
let local_row = local_idx / tile_size_k_vec;

#if has_bias
#if has_weight_idx
let b_bias_offset = uniforms.weight_idx * uniforms.N;
#else
const b_bias_offset : u32 = 0;
#endif
#endif

#if n_bits == 2
Expand All @@ -78,8 +83,11 @@ $MAIN {
#endif
#if single_scale_weights
let zero = mm_read_zero(0, 0, uniforms.N, uniforms.zero_blocks_per_col);
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K / uniforms.block_size);
let own_scale_b = scales_b.getByOffset(b_scale_offset);
#if has_weight_idx
let own_scale_b = scales_b.getByOffset(uniforms.weight_idx);
#else
let own_scale_b = scales_b.getByOffset(0);
#endif
#endif

for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec)
Expand All @@ -101,12 +109,17 @@ $MAIN {
let b_global = b_global_base + row_offset + local_row;
if (b_global < uniforms.N && k_offset < uniforms.K32)
{
#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K32;
let b_offset = b_weight_offset + b_global * uniforms.K32 + k_offset;
#if !single_scale_weights
let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K / uniforms.block_size);
let own_scale_b = scales_b.getByOffset(b_scale_offset + b_global * uniforms.K / uniforms.block_size + block_idx);
#else
let b_offset = b_global * uniforms.K32 + k_offset;
let own_scale_b = scales_b.getByOffset(b_global * uniforms.K / uniforms.block_size + block_idx);
#endif
#if !single_scale_weights
let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
#endif
#if n_bits == 4
let b_value = b.getByOffset(b_offset);
Expand Down
Loading
Loading