Skip to content

Commit 3d44523

Browse files
committed
fix nz for quantization
Signed-off-by: wangxiyuan <[email protected]>
1 parent ceadc27 commit 3d44523

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from vllm_ascend.ascend_config import get_ascend_config
2828
from vllm_ascend.distributed.parallel_state import get_mc2_group
2929
from vllm_ascend.ops.moe.experts_selector import select_experts
30-
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
30+
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
3131

3232

3333
class AscendW4A8DynamicLinearMethod:
@@ -482,10 +482,9 @@ def process_weights_after_loading(self, layer):
482482

483483
self.update_bias(layer, w13_bias, w2_bias)
484484

485-
if is_enable_nz():
486-
layer.w13_weight.data = torch_npu.npu_format_cast(
487-
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
488-
layer.w2_weight.data = torch_npu.npu_format_cast(
489-
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
485+
layer.w13_weight.data = torch_npu.npu_format_cast(
486+
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
487+
layer.w2_weight.data = torch_npu.npu_format_cast(
488+
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
490489
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
491490
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

vllm_ascend/quantization/w8a8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def process_weights_after_loading(self, layer):
347347
# converting ACL_FORMAT_FRACTAL_NZ.
348348
# npu_quant_grouped_matmul_dequant in eager mode does not accept
349349
# ACL_FORMAT_FRACTAL_NZ.
350-
if not is_310p() and is_enable_nz():
350+
if not is_310p():
351351
layer.w13_weight.data = torch_npu.npu_format_cast(
352352
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
353353
layer.w2_weight.data = torch_npu.npu_format_cast(

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,8 @@ def process_weights_after_loading(self, layer):
270270
1, 2).contiguous()
271271
layer.w2_weight.data = layer.w2_weight.data.transpose(
272272
1, 2).contiguous()
273-
if is_enable_nz():
274-
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
275-
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
273+
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
274+
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
276275
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
277276
layer.w13_weight_scale.data.shape[0], -1)
278277
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(

0 commit comments

Comments (0)