Commit 423a03f

Authored by guschmue and Copilot

webgpu / nbitmm support for bias and weight_index (#26392)
Add support for bias and weight_index, move subgroup_matrix_matmul_nbits to a template, and make the program callable from other ops.

Co-authored-by: Copilot <[email protected]>
1 parent 3d926ac, commit 423a03f

14 files changed: +625 -284 lines

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template

Lines changed: 98 additions & 7 deletions
@@ -3,8 +3,10 @@
 
 #param block_size
 #param n_bits
+#param has_bias
 #param has_zero_points
 #param is_qualcomm
+#param has_weight_idx
 
 #use .getByOffset .setByOffset
 
@@ -75,15 +77,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
   {
     return;
   }
-
-  let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
+#if has_weight_idx
+  let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
+  let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
+#else
+  let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
+#endif
   let block_idx = kidx_v/(block_size/16);
   let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
   tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero);
   if (col == 0)
   {
     // kidx_v - each kidx_v covers 16 values of k
-    scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
+    let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
+    scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
   }
 }
 #endif
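The weight_idx branches above all apply the same per-weight base offset, which assumes several quantized weight matrices are stacked contiguously in one buffer: b laid out as [num_weights, N, K/16] packed u32 values and scales_b as [num_weights, N, K/block_size]. A minimal host-side sketch of that address arithmetic, assuming this layout (the helper names are mine, not from the sources):

#include <cstdint>

// Flat element offsets into stacked quantized weights, mirroring the shader:
//   b:        [num_weights, N, K16]            where K16 = K / 16 packed values
//   scales_b: [num_weights, N, K / block_size] with one scale per block
// weight_offset_b / weight_offset_scales are hypothetical helpers for illustration.
inline uint32_t weight_offset_b(uint32_t weight_idx, uint32_t n, uint32_t k16_idx,
                                uint32_t N, uint32_t K16) {
  return weight_idx * N * K16 + n * K16 + k16_idx;
}

inline uint32_t weight_offset_scales(uint32_t weight_idx, uint32_t n, uint32_t block_idx,
                                     uint32_t N, uint32_t K, uint32_t block_size) {
  const uint32_t blocks_per_col = K / block_size;  // scales per output column
  return weight_idx * N * blocks_per_col + n * blocks_per_col + block_idx;
}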
@@ -97,13 +104,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
     return;
   }
 
-  let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
+#if has_weight_idx
+  let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
+  let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
+#else
+  const b_weight_offset : u32 = 0;
+  let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
+#endif
   tile_B[col][row] = AlignWithZeroPoint(b_value);
   if (col == 0)
   {
     // kidx_v - each kidx_v covers 16 values of k
     let block_idx = kidx_v/(block_size/16);
-    scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
+    let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
+    scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
 #if has_zero_points
     zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
 #endif
@@ -119,10 +133,17 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
   {
     return;
   }
-  let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
+#if has_weight_idx
+  let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
+  let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
+#else
+  const b_weight_offset : u32 = 0;
+  let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
+#endif
   tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
   let block_idx = kidx_v/(block_size/16);
-  scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
+  let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
+  scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
 }
 #endif
 
@@ -360,19 +381,89 @@ $MAIN {
   let a_global = a_global_base + base_A + a_idx;
   let b_global = b_global_base + base_B;
   let output_idx = ((a_global) * uniforms.N + b_global)/4;
+#if has_bias
+#if has_weight_idx
+  let b_bias_offset = uniforms.weight_idx * uniforms.N;
+#else
+  const b_bias_offset : u32 = 0;
+#endif
+#endif
   // This creates a shader requirement that uniforms.N % 16 == 0
   if (a_global < uniforms.M && b_global < uniforms.N)
   {
 #if is_qualcomm
+#if has_bias
+    let bias_vec1 = vec4<output_element_t>(
+      bias[b_global + 0 + b_bias_offset],
+      bias[b_global + 1 + b_bias_offset],
+      bias[b_global + 2 + b_bias_offset],
+      bias[b_global + 3 + b_bias_offset]
+    );
+    let bias_vec2 = vec4<output_element_t>(
+      bias[b_global + 4 + b_bias_offset],
+      bias[b_global + 5 + b_bias_offset],
+      bias[b_global + 6 + b_bias_offset],
+      bias[b_global + 7 + b_bias_offset]
+    );
+    let bias_vec3 = vec4<output_element_t>(
+      bias[b_global + 8 + b_bias_offset],
+      bias[b_global + 9 + b_bias_offset],
+      bias[b_global + 10 + b_bias_offset],
+      bias[b_global + 11 + b_bias_offset]
+    );
+    let bias_vec4 = vec4<output_element_t>(
+      bias[b_global + 12 + b_bias_offset],
+      bias[b_global + 13 + b_bias_offset],
+      bias[b_global + 14 + b_bias_offset],
+      bias[b_global + 15 + b_bias_offset]
+    );
+    output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]) + bias_vec1);
+    output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]) + bias_vec2);
+    output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]) + bias_vec3);
+    output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]) + bias_vec4);
+#else
     output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]));
     output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]));
     output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]));
    output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]));
+#endif
 #else
+#if has_bias
+    // TODO: wanted to use vec4 for bias but for some reason that fails ut. Later.
+    let bias_vec1 = vec4<output_element_t>(
+      bias[b_global + 0 + b_bias_offset],
+      bias[b_global + 1 + b_bias_offset],
+      bias[b_global + 2 + b_bias_offset],
+      bias[b_global + 3 + b_bias_offset]
+    );
+    let bias_vec2 = vec4<output_element_t>(
+      bias[b_global + 4 + b_bias_offset],
+      bias[b_global + 5 + b_bias_offset],
+      bias[b_global + 6 + b_bias_offset],
+      bias[b_global + 7 + b_bias_offset]
+    );
+    let bias_vec3 = vec4<output_element_t>(
+      bias[b_global + 8 + b_bias_offset],
+      bias[b_global + 9 + b_bias_offset],
+      bias[b_global + 10 + b_bias_offset],
+      bias[b_global + 11 + b_bias_offset]
+    );
+    let bias_vec4 = vec4<output_element_t>(
+      bias[b_global + 12 + b_bias_offset],
+      bias[b_global + 13 + b_bias_offset],
+      bias[b_global + 14 + b_bias_offset],
+      bias[b_global + 15 + b_bias_offset]
+    );
+    output.setByOffset(output_idx, lane_output1 + bias_vec1);
+    output.setByOffset(output_idx+1, lane_output2 + bias_vec2);
+    output.setByOffset(output_idx+2, lane_output3 + bias_vec3);
+    output.setByOffset(output_idx+3, lane_output4 + bias_vec4);
+#else
     output.setByOffset(output_idx, lane_output1);
     output.setByOffset(output_idx+1, lane_output2);
     output.setByOffset(output_idx+2, lane_output3);
     output.setByOffset(output_idx+3, lane_output4);
+#endif
 #endif
   }
 } // MAIN
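Element-wise, this has_bias epilogue just adds a per-column bias, shifted by weight_index * N when a weight index is in use, to every output row. A hedged reference loop over the unvectorized math (plain C++ with illustrative names, not the shader):

#include <cstdint>
#include <vector>

// Reference semantics of the has_bias epilogue: y[m][n] += bias[b_bias_offset + n],
// where b_bias_offset = weight_index * N when has_weight_idx is set, else 0.
void add_bias(std::vector<float>& y, const std::vector<float>& bias,
              uint32_t M, uint32_t N, uint32_t weight_index) {
  const uint32_t b_bias_offset = weight_index * N;  // 0 when no weight index is used
  for (uint32_t m = 0; m < M; ++m) {
    for (uint32_t n = 0; n < N; ++n) {
      y[m * N + n] += bias[b_bias_offset + n];
    }
  }
}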

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 29 additions & 8 deletions
@@ -27,9 +27,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_zero_points_) {
     shader.AddInput("zero_points", ShaderUsage::UseUniform);
   }
+  if (has_bias_) {
+    shader.AddInput("bias", ShaderUsage::UseUniform);
+  }
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
   return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template",
                              WGSL_TEMPLATE_PARAMETER(block_size, block_size_),
+                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
+                             WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
                              WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
                              WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                              WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
@@ -50,13 +55,18 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
   if (has_zero_points_) {
     shader.AddInput("zero_points", ShaderUsage::UseUniform);
   }
+  if (has_bias_) {
+    shader.AddInput("bias", ShaderUsage::UseUniform);
+  }
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
 
   ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4");
   const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec_;
   ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count");
 
   return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
+                             WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
                              WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
                              WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
                              WGSL_TEMPLATE_PARAMETER(output_type_i32, true),
@@ -72,7 +82,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
 }
 
 Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
-                                  const Tensor* zero_points,
+                                  const Tensor* zero_points, const Tensor* bias,
                                   uint32_t M,
                                   uint32_t N,
                                   uint32_t K,
@@ -81,7 +91,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
                                   uint32_t min_M_for_tile_optimization,
                                   uint32_t nbits,
                                   onnxruntime::webgpu::ComputeContext& context,
-                                  Tensor* y) {
+                                  Tensor* y,
+                                  const uint32_t weight_index) {
   constexpr uint32_t kVec4Components = 4;
   constexpr uint32_t kVec2Components = 2;
   constexpr uint32_t kU32Components = 4;
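The widened signature threads two new optional inputs through: a bias tensor, where nullptr preserves the old behavior, and a weight_index that stays zero unless the caller addresses a stacked weight matrix. A compilable toy of that contract (stub types and names, not the real ONNX Runtime declarations; only the nullptr/zero conventions come from this commit):

#include <cstdint>
#include <cstdio>

struct Tensor {};  // stand-in for onnxruntime::Tensor, for illustration only

// Mirrors how the entry point derives its feature flags from the new arguments.
void ApplyDP4AMatMulNBitsSketch(const Tensor* zero_points, const Tensor* bias,
                                uint32_t weight_index) {
  const bool has_zero_points = zero_points != nullptr;
  const bool has_bias = bias != nullptr;
  const bool has_weight_idx = weight_index != 0;  // 0 selects the offset-free shader
  std::printf("has_zero_points=%d has_bias=%d has_weight_idx=%d\n",
              has_zero_points, has_bias, has_weight_idx);
}

int main() {
  Tensor bias;
  ApplyDP4AMatMulNBitsSketch(/*zero_points=*/nullptr, &bias, /*weight_index=*/1);
  return 0;
}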
@@ -100,7 +111,10 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
                          {&a_scale, ProgramTensorMetadataDependency::Rank, 1}})
       .AddUniformVariable({M * K / kU32Components});
   ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
+
   const bool has_zero_points = zero_points != nullptr;
+  const bool has_bias = bias != nullptr;
+  const bool has_weight_idx = weight_index != 0;
   const bool single_scale_weights = (block_size == K * N);
   if (M < min_M_for_tile_optimization) {
     uint32_t tile_size_k_vec = 16;
@@ -111,20 +125,23 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
       tile_size_n = 4;
     }
     const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components);
-    DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, single_scale_weights};
+    DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights};
     uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
     mul_program.SetWorkgroupSize(128);
     mul_program.SetDispatchGroupSize(M * num_N_tile);
     mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                            {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
                            {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(b_components * kU32Components)},
                            {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
-        .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col})
+        .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
         .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
-        .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights);
+        .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx);
     if (has_zero_points) {
       mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
     }
+    if (has_bias) {
+      mul_program.AddInput({bias, ProgramTensorMetadataDependency::None});
+    }
     return context.RunProgram(mul_program);
   }
 
@@ -133,7 +150,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
   uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
   uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
-  DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm};
+  DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm};
   mul_program.SetWorkgroupSize(256);
   mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile);
   mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
@@ -146,12 +163,16 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
                          {static_cast<uint32_t>(K / 8)},
                          {static_cast<uint32_t>(K / 16)},
                          {num_N_tile},
-                         {zero_blocks_per_col}})
+                         {zero_blocks_per_col},
+                         {weight_index}})
       .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)})
-      .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm);
+      .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx);
   if (has_zero_points) {
     mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
   }
+  if (has_bias) {
+    mul_program.AddInput({bias, ProgramTensorMetadataDependency::None});
+  }
   return context.RunProgram(mul_program);
 }
 
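Note that both CacheHint calls fold the new flags in. The WebGPU EP caches compiled pipelines per cache hint, so omitting has_bias or has_weight_idx would let a shader compiled without the bias binding be reused for a dispatch that declares one. An illustrative key construction showing why each variant-changing flag must appear in the key (this is not the actual cache code):

#include <cstdint>
#include <sstream>
#include <string>

// Illustrative cache-key sketch: every parameter that changes the generated WGSL
// must be part of the key, otherwise two incompatible shader variants collide.
std::string MakeCacheKey(uint32_t block_size, uint32_t nbits, bool has_zero_points,
                         bool is_qualcomm, bool has_bias, bool has_weight_idx) {
  std::ostringstream key;
  key << "Block" << block_size << '|' << nbits << '|' << has_zero_points << '|'
      << is_qualcomm << '|' << has_bias << '|' << has_weight_idx;
  return key.str();
}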