Skip to content

Commit f983ca3

Browse files
[Bugfix] Fix Qwen2.5-Omni-7B accuracy test (vllm-project#4556)
### What this PR does / why we need it? Fix the Qwen2.5-Omni-7B accuracy test issue: vllm-project#4480 Depends on: vllm-project#4534 - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: hfadzxy <[email protected]>
1 parent bad1ab4 commit f983ca3

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

vllm_ascend/ops/layernorm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ def forward_oot(
108108
residual: Optional[torch.Tensor] = None,
109109
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
110110
import torch_npu
111-
112111
if residual is not None:
113112
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
114113
assert x.size(0) == residual.size(0)
114+
next_need_quant_fusion_linear = getattr(
115+
self, 'next_need_quant_fusion_linear', None)
115116
x, residual = _addrmsnorm_forward_oot(
116-
self, x, residual, self.next_need_quant_fusion_linear,
117-
self.bias)
117+
self, x, residual, next_need_quant_fusion_linear, self.bias)
118118
return x, residual
119119
x, residual = torch_npu.npu_rms_norm(x, self.weight,
120120
self.variance_epsilon)

vllm_ascend/ops/register_custom_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
173173
except AssertionError:
174174
return
175175

176-
if not forward_context.prefetch_mlp_enabled:
176+
prefetch_mlp_enabled = getattr(forward_context, 'prefetch_mlp_enabled',
177+
False)
178+
if not prefetch_mlp_enabled:
177179
return
178180
forward_context.prefetch_mlp_down_proj = True
179181
model_instance = forward_context.model_instance
@@ -202,7 +204,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
202204
except AssertionError:
203205
return
204206

205-
if not forward_context.prefetch_mlp_enabled:
207+
prefetch_mlp_enabled = getattr(forward_context, 'prefetch_mlp_enabled',
208+
False)
209+
if not prefetch_mlp_enabled:
206210
return
207211
if forward_context.prefetch_mlp_gate_up_proj or \
208212
forward_context.prefetch_mlp_down_proj:

0 commit comments

Comments
 (0)