Skip to content

Commit a462fcd

Browse files
author
刘哲续
committed
modify nz in bf16
Signed-off-by: 刘哲续 <[email protected]>
1 parent 7d1db34 commit a462fcd

File tree

8 files changed

+13
-12
lines changed

8 files changed

+13
-12
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
652652

653653
# Function `get_and_maybe_dequant_weights` will cast the weights to
654654
# FRACTAL_ND. So we need to cast to FRACTAL_NZ again.
655-
if is_enable_nz():
655+
if is_enable_nz(self.kv_b_proj.weight.data.dtype):
656656
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
657657
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
658658

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def pad_qkv_weight(self, data):
284284
dim=2)
285285
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
286286

287-
if is_enable_nz():
287+
if is_enable_nz(qkv_weight_final.dtype):
288288
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
289289
qkv_weight_final)
290290
qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -300,7 +300,7 @@ def pad_proj_weight(self, data):
300300
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
301301
self.hidden_size, -1)
302302

303-
if is_enable_nz():
303+
if is_enable_nz(out_weight.dtype):
304304
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
305305
out_weight_copy = torch_npu.npu_format_cast(
306306
out_weight_copy, ACL_FORMAT_FRACTAL_ND)

vllm_ascend/models/qwen2_vl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def pad_qkv_weight(self, data):
268268
dim=2)
269269
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
270270

271-
if is_enable_nz():
271+
if is_enable_nz(qkv_weight_final.dtype):
272272
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
273273
qkv_weight_final)
274274
qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -284,7 +284,7 @@ def pad_proj_weight(self, data):
284284
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
285285
self.hidden_size, -1)
286286

287-
if is_enable_nz():
287+
if is_enable_nz(out_weight.dtype):
288288
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
289289
out_weight_copy = torch_npu.npu_format_cast(
290290
out_weight_copy, ACL_FORMAT_FRACTAL_ND)

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def process_weights_after_loading(self, layer):
8989
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
9090
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
9191

92-
if not is_310p() and is_enable_nz():
92+
if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype):
9393
layer.w13_weight.data = torch_npu.npu_format_cast(
9494
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
9595
layer.w2_weight.data = torch_npu.npu_format_cast(

vllm_ascend/ops/linear.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
4545

4646
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
4747
super().process_weights_after_loading(layer)
48-
if (is_enable_nz() and layer.weight.data.dtype
49-
in [torch.float16, torch.bfloat16]):
48+
if (is_enable_nz(layer.weight.data.dtype)):
5049
layer.weight.data = torch_npu.npu_format_cast(
5150
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
5251

vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def process_weights_after_loading(self, layer):
835835
if self.transpose_weight:
836836
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
837837
# cast quantized weight tensors in NZ format (29) for higher inference speed
838-
if is_enable_nz():
838+
if is_enable_nz(layer.weight.data.dtype):
839839
layer.weight.data = torch_npu.npu_format_cast(
840840
layer.weight.data, 29)
841841
layer.weight_scale.data = layer.weight_scale.data.flatten()

vllm_ascend/torchair/torchair_sfa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
842842
wd_qkv = wd_qkv.t().contiguous()
843843
wd_qkv = transdata(wd_qkv,
844844
block_size=(16, 32)).unsqueeze(0).contiguous()
845-
if is_enable_nz():
845+
if is_enable_nz(wd_qkv.dtype):
846846
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
847847

848848
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
@@ -876,7 +876,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
876876
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
877877
-1)
878878
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
879-
if is_enable_nz():
879+
if is_enable_nz(wu_q.dtype):
880880
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
881881

882882
qb_deq_scl = self.q_proj.deq_scale.data.clone()

vllm_ascend/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,15 @@ def is_310p():
7171
return _IS_310P
7272

7373

74-
def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
74+
def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8, vllm_config: Optional[VllmConfig] = None) -> bool:
    """Return whether weights of *dtype* should be cast to FRACTAL_NZ format.

    NZ is gated globally by ``VLLM_ASCEND_ENABLE_NZ`` (and force-disabled for
    ``qwen3_next`` models); that decision is computed once and cached in the
    module-level ``_ENABLE_NZ``.  Independently of the global gate,
    float16/bfloat16 weights never use NZ.

    Args:
        dtype: dtype of the weight tensor being considered; defaults to
            ``torch.int8`` so quantized-weight call sites keep the old
            zero-argument behavior.
        vllm_config: required only for the first call, to initialize the
            cached global gate.

    Raises:
        ValueError: if the gate is uninitialized and no ``vllm_config`` is
            provided.
    """
    global _ENABLE_NZ
    if _ENABLE_NZ is None:
        if not vllm_config:
            raise ValueError(
                "vllm_config must be provided when _ENABLE_NZ is None")
        _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
    # float16/bfloat16 weights must stay in ND format.  Return False for this
    # call ONLY: the previous code did `_ENABLE_NZ = 0` here, which clobbered
    # the cached global and permanently disabled NZ for every later caller
    # (including int8 quantized weights) once any bf16/fp16 weight was seen.
    if dtype in (torch.float16, torch.bfloat16):
        return False
    return bool(_ENABLE_NZ)
8284

8385

0 commit comments

Comments
 (0)