Skip to content

Commit 3dc9385

Browse files
committed
resolve conflict
Signed-off-by: wxsIcey <[email protected]>
1 parent ac08c70 commit 3dc9385

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@
122 122
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
123 123
AscendDeviceType, ProfileExecuteDuration,
124 124
enable_sp, get_ascend_device_type, is_enable_nz,
125-
is_moe_model, lmhead_tp_enable)
125+
is_moe_model, is_vl_model, lmhead_tp_enable)
126 126
from vllm_ascend.worker.npu_input_batch import InputBatch
127 127

128 128
if TYPE_CHECKING:
@@ -270,6 +270,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
270 270

271 271
set_cos_and_sin(vllm_config, self.max_num_reqs,
272 272
self.uniform_decode_query_len, self.dtype, self.device)
273+
if not is_vl_model(self.vllm_config
274+
) and not self.vllm_config.model_config.use_mla:
275+
initialize_cos_sin(self.vllm_config, self.dtype, self.device)
273 276
set_mc2_tokens_capacity(vllm_config, self.max_num_reqs,
274 277
self.uniform_decode_query_len)
275 278
set_mc2_mask(vllm_config, self.device)
@@ -2198,6 +2201,9 @@ def _dummy_run(
2198 2201
else:
2199 2202
positions = self.positions.gpu[:num_tokens_padded]
2200 2203

2204+
# update global cos, sin
2205+
update_cos_sin(positions)
2206+
2201 2207
if get_pp_group().is_first_rank:
2202 2208
intermediate_tensors = None
2203 2209
else:

0 commit comments

Comments
 (0)