@@ -27,9 +27,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
2727 if (has_zero_points_) {
2828 shader.AddInput (" zero_points" , ShaderUsage::UseUniform);
2929 }
30+ if (has_bias_) {
31+ shader.AddInput (" bias" , ShaderUsage::UseUniform);
32+ }
3033 const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
3134 return WGSL_TEMPLATE_APPLY (shader, " quantization/dp4a_matmul.wgsl.template" ,
3235 WGSL_TEMPLATE_PARAMETER (block_size, block_size_),
36+ WGSL_TEMPLATE_PARAMETER (has_bias, has_bias_),
37+ WGSL_TEMPLATE_PARAMETER (has_weight_idx, has_weight_idx_),
3338 WGSL_TEMPLATE_PARAMETER (has_zero_points, has_zero_points_),
3439 WGSL_TEMPLATE_PARAMETER (is_qualcomm, is_qualcomm_),
3540 WGSL_TEMPLATE_PARAMETER (n_bits, nbits_),
@@ -50,13 +55,18 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
5055 if (has_zero_points_) {
5156 shader.AddInput (" zero_points" , ShaderUsage::UseUniform);
5257 }
58+ if (has_bias_) {
59+ shader.AddInput (" bias" , ShaderUsage::UseUniform);
60+ }
5361 const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
5462
5563 ORT_ENFORCE (WorkgroupSizeX () % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0 , " tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4" );
5664 const uint32_t sub_tile_count = WorkgroupSizeX () / tile_size_k_vec_;
5765 ORT_ENFORCE (tile_size_ % sub_tile_count == 0 , " tile_size_ must be divisible by sub_tile_count" );
5866
5967 return WGSL_TEMPLATE_APPLY (shader, " quantization/dp4a_matmul_small_m.wgsl.template" ,
68+ WGSL_TEMPLATE_PARAMETER (has_bias, has_bias_),
69+ WGSL_TEMPLATE_PARAMETER (has_weight_idx, has_weight_idx_),
6070 WGSL_TEMPLATE_PARAMETER (has_zero_points, has_zero_points_),
6171 WGSL_TEMPLATE_PARAMETER (n_bits, nbits_),
6272 WGSL_TEMPLATE_PARAMETER (output_type_i32, true ),
@@ -72,7 +82,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
7282}
7383
7484Status ApplyDP4AMatrixMatMulNBits (const Tensor* a, const Tensor* b, const Tensor* scales,
75- const Tensor* zero_points,
85+ const Tensor* zero_points, const Tensor* bias,
7686 uint32_t M,
7787 uint32_t N,
7888 uint32_t K,
@@ -81,7 +91,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
8191 uint32_t min_M_for_tile_optimization,
8292 uint32_t nbits,
8393 onnxruntime::webgpu::ComputeContext& context,
84- Tensor* y) {
94+ Tensor* y,
95+ const uint32_t weight_index) {
8596 constexpr uint32_t kVec4Components = 4 ;
8697 constexpr uint32_t kVec2Components = 2 ;
8798 constexpr uint32_t kU32Components = 4 ;
@@ -100,7 +111,10 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
100111 {&a_scale, ProgramTensorMetadataDependency::Rank, 1 }})
101112 .AddUniformVariable ({M * K / kU32Components });
102113 ORT_RETURN_IF_ERROR (context.RunProgram (quantize_program));
114+
103115 const bool has_zero_points = zero_points != nullptr ;
116+ const bool has_bias = bias != nullptr ;
117+ const bool has_weight_idx = weight_index != 0 ;
104118 const bool single_scale_weights = (block_size == K * N);
105119 if (M < min_M_for_tile_optimization) {
106120 uint32_t tile_size_k_vec = 16 ;
@@ -111,20 +125,23 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
111125 tile_size_n = 4 ;
112126 }
113127 const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components );
114- DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, single_scale_weights};
128+ DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights};
115129 uint32_t num_N_tile = (N + tile_size_n - 1 ) / tile_size_n;
116130 mul_program.SetWorkgroupSize (128 );
117131 mul_program.SetDispatchGroupSize (M * num_N_tile);
118132 mul_program.AddInputs ({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec4Components )},
119133 {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1 },
120134 {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(b_components * kU32Components )},
121135 {scales, ProgramTensorMetadataDependency::TypeAndRank, 1 }})
122- .AddUniformVariables ({M, N, K, K / 16 , K / 32 , block_size, num_N_tile, zero_blocks_per_col})
136+ .AddUniformVariables ({M, N, K, K / 16 , K / 32 , block_size, num_N_tile, zero_blocks_per_col, weight_index })
123137 .AddOutput ({y, ProgramTensorMetadataDependency::TypeAndRank, 1 })
124- .CacheHint (nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights);
138+ .CacheHint (nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx );
125139 if (has_zero_points) {
126140 mul_program.AddInput ({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape ().Size () + 3 ) / 4 }, 4 });
127141 }
142+ if (has_bias) {
143+ mul_program.AddInput ({bias, ProgramTensorMetadataDependency::None});
144+ }
128145 return context.RunProgram (mul_program);
129146 }
130147
@@ -133,7 +150,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
133150 uint32_t num_M_tile = (M + kTileSize - 1 ) / kTileSize ;
134151 uint32_t num_N_tile = (N + kTileSize - 1 ) / kTileSize ;
135152 bool is_qualcomm = context.AdapterInfo ().vendor == std::string_view{" qualcomm" };
136- DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm};
153+ DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm};
137154 mul_program.SetWorkgroupSize (256 );
138155 mul_program.SetDispatchGroupSize (num_M_tile * num_N_tile);
139156 mul_program.AddInputs ({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec4Components )},
@@ -146,12 +163,16 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
146163 {static_cast <uint32_t >(K / 8 )},
147164 {static_cast <uint32_t >(K / 16 )},
148165 {num_N_tile},
149- {zero_blocks_per_col}})
166+ {zero_blocks_per_col},
167+ {weight_index}})
150168 .AddOutput ({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast <int >(kVec4Components )})
151- .CacheHint (" Block" + std::to_string (block_size), nbits, has_zero_points, is_qualcomm);
169+ .CacheHint (" Block" + std::to_string (block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx );
152170 if (has_zero_points) {
153171 mul_program.AddInput ({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape ().Size () + 3 ) / 4 }, 4 });
154172 }
173+ if (has_bias) {
174+ mul_program.AddInput ({bias, ProgramTensorMetadataDependency::None});
175+ }
155176 return context.RunProgram (mul_program);
156177}
157178
0 commit comments