Commit cb596f5

fix nz for quantization

Signed-off-by: wangxiyuan <[email protected]>
1 parent 5932abc

4 files changed (+11, -14 lines)

vllm_ascend/quantization/w4a8_dynamic.py (5 additions, 6 deletions)
@@ -27,7 +27,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ


 class AscendW4A8DynamicLinearMethod:

@@ -482,10 +482,9 @@ def process_weights_after_loading(self, layer):

         self.update_bias(layer, w13_bias, w2_bias)

-        if is_enable_nz():
-            layer.w13_weight.data = torch_npu.npu_format_cast(
-                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
-            layer.w2_weight.data = torch_npu.npu_format_cast(
-                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.w13_weight.data = torch_npu.npu_format_cast(
+            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.w2_weight.data = torch_npu.npu_format_cast(
+            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
         layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
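Taken together, the new tail of process_weights_after_loading casts to NZ first and packs to int32 second, with no is_enable_nz() gate. A condensed sketch of the resulting flow (the helper name finalize_w4a8_moe_weights is hypothetical; layer and the class method pack_to_int32 are assumed from the surrounding file, and an Ascend NPU runtime is required):

import torch_npu  # Ascend extension providing npu_format_cast

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ


def finalize_w4a8_moe_weights(layer, pack_to_int32):
    # Both expert weight tensors are now always cast to the FRACTAL_NZ
    # layout; before this commit the casts ran only when is_enable_nz()
    # returned True.
    layer.w13_weight.data = torch_npu.npu_format_cast(
        layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
    layer.w2_weight.data = torch_npu.npu_format_cast(
        layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
    # Packing int4 values into int32 words happens after the format cast,
    # matching the order in the diff above.
    layer.w13_weight.data = pack_to_int32(layer.w13_weight.data)
    layer.w2_weight.data = pack_to_int32(layer.w2_weight.data)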

vllm_ascend/quantization/w8a8.py (1 addition, 1 deletion)
@@ -347,7 +347,7 @@ def process_weights_after_loading(self, layer):
         # converting ACL_FORMAT_FRACTAL_NZ.
         # npu_quant_grouped_matmul_dequant in eager mode does not accept
         # ACL_FORMAT_FRACTAL_NZ.
-        if not is_310p() and is_enable_nz():
+        if not is_310p():
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
             layer.w2_weight.data = torch_npu.npu_format_cast(
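After this change the only guard left on the cast is the SoC check. A minimal sketch of the intent, assuming the is_310p helper exported by vllm_ascend.utils (the function name cast_moe_weights_to_nz is illustrative, and the trailing .contiguous() on the second cast is inferred from the pattern above, since the hunk is truncated):

import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p


def cast_moe_weights_to_nz(layer):
    # Skip the cast on 310P: per the comment in the diff,
    # npu_quant_grouped_matmul_dequant in eager mode does not accept
    # ACL_FORMAT_FRACTAL_NZ.
    if not is_310p():
        layer.w13_weight.data = torch_npu.npu_format_cast(
            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
        layer.w2_weight.data = torch_npu.npu_format_cast(
            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()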

vllm_ascend/quantization/w8a8_dynamic.py (2 additions, 3 deletions)
@@ -270,9 +270,8 @@ def process_weights_after_loading(self, layer):
             1, 2).contiguous()
         layer.w2_weight.data = layer.w2_weight.data.transpose(
             1, 2).contiguous()
-        if is_enable_nz():
-            torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
-            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
+        torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
+        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
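Note the API difference from the two files above: npu_format_cast returns a new tensor that must be assigned back, while npu_format_cast_ (used here) converts the tensor in place. A minimal sketch of the contrast, assuming an Ascend NPU device is available:

import torch
import torch_npu  # registers the "npu" device and format-cast ops

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

w_out = torch.zeros(128, 256, dtype=torch.int8, device="npu")
w_in = torch.zeros(128, 256, dtype=torch.int8, device="npu")

# Out-of-place: the NZ-formatted result is a new tensor; reassign it.
w_out = torch_npu.npu_format_cast(w_out, ACL_FORMAT_FRACTAL_NZ)

# In-place: converts the existing tensor's layout, no reassignment needed.
torch_npu.npu_format_cast_(w_in, ACL_FORMAT_FRACTAL_NZ)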

vllm_ascend/utils.py (3 additions, 4 deletions)
@@ -81,10 +81,9 @@ def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
                 "vllm_config must be provided when _ENABLE_NZ is None")
         _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"

-        _IS_EAGLE_MODE = (
-            vllm_config.speculative_config is not None and
-            getattr(vllm_config.speculative_config, 'method', None) in ("eagle", "eagle3")
-        )
+        _IS_EAGLE_MODE = (vllm_config.speculative_config is not None
+                          and getattr(vllm_config.speculative_config, 'method',
+                                      None) in ("eagle", "eagle3"))

     if dtype in [torch.float16, torch.bfloat16]:
         return _ENABLE_NZ if _IS_EAGLE_MODE else False
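For reference, a hypothetical standalone condensation of the gating this hunk reformats: for fp16/bf16 weights, NZ is honored only when eagle/eagle3 speculative decoding is active. The function below is not from the source; the module-level cache variables and the int8 branch of is_enable_nz are outside this diff and are only approximated:

from typing import Optional

import torch


def nz_enabled_for_float(dtype: Optional[torch.dtype], enable_nz: bool,
                         is_eagle_mode: bool) -> bool:
    # Mirrors the context lines above: fp16/bf16 weights get NZ only in
    # eagle/eagle3 speculative-decoding mode.
    if dtype in (torch.float16, torch.bfloat16):
        return enable_nz if is_eagle_mode else False
    # Assumed fallthrough for other dtypes (the real int8 branch is not
    # shown in this commit's diff).
    return enable_nz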
