Commit 20568f7

tiny fix

Signed-off-by: wxsIcey <[email protected]>
1 parent be584fa

2 files changed, 2 insertions(+), 60 deletions(-)

vllm_ascend/ascend_forward_context.py (0 additions, 6 deletions)

@@ -72,8 +72,6 @@ def set_ascend_forward_context(
         prefetch_stream: torch.npu.Stream = None,
         model_instance: torch.nn.Module = None,
         weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
         is_mtp_model=False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -162,10 +160,6 @@ def set_ascend_forward_context(
     forward_context.weight_prefetch_method = weight_prefetch_method
     forward_context.is_mtp_model = is_mtp_model

-    # initialize rope
-    forward_context.cos = cos
-    forward_context.sin = sin
-
     if num_tokens is None and attn_metadata is not None:
         num_tokens = attn_metadata.num_actual_tokens
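In short, this file's change drops the rope plumbing from the context manager: cos and sin are no longer accepted as keyword arguments or stored on the forward context. A minimal, hypothetical sketch of the removed pattern, with ForwardContext reduced to a stub (the real class carries far more state):

    from typing import Optional
    import torch

    class ForwardContext:
        # Stub with only the two fields this commit removes.
        cos: Optional[torch.Tensor] = None
        sin: Optional[torch.Tensor] = None

    def set_forward_context_sketch(ctx: ForwardContext,
                                   cos: Optional[torch.Tensor] = None,
                                   sin: Optional[torch.Tensor] = None) -> None:
        # Before this commit, set_ascend_forward_context() stashed the
        # precomputed rope tables on the context for attention layers to read.
        ctx.cos = cos
        ctx.sin = sin

    ctx = ForwardContext()
    set_forward_context_sketch(ctx, torch.ones(1, 4, 1, 8), torch.zeros(1, 4, 1, 8))
    assert ctx.cos is not None and ctx.sin is not None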
vllm_ascend/worker/model_runner_v1.py (2 additions, 54 deletions)

@@ -415,20 +415,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                     rope_dim,
                                     dtype=self.dtype,
                                     device=self.device)
-        # For GQA models.
-        elif not self.vllm_config.model_config.use_mla:
-            self.cos = torch.ones(1,
-                                  self.max_num_tokens,
-                                  1,
-                                  128,
-                                  dtype=self.dtype,
-                                  device=self.device)
-            self.sin = torch.zeros(1,
-                                   self.max_num_tokens,
-                                   1,
-                                   128,
-                                   dtype=self.dtype,
-                                   device=self.device)
         else:
             self.cos = None
             self.sin = None
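The deleted branch preallocated persistent cos/sin tables for GQA (that is, non-MLA) models, with the head dim hard-coded to 128. The (1, max_num_tokens, 1, 128) shape is the "BSNH" layout named in the hunks below: one table row per position, shared across batch and heads via broadcasting. A small illustrative sketch; how the tables were consumed is not shown in this diff, so the NeoX-style rotation here is an assumption:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # GPT-NeoX-style half rotation (assumed; not shown in the diff).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    B, S, N, D = 2, 16, 8, 128     # head dim 128 was hard-coded in the deleted branch
    q = torch.randn(B, S, N, D)
    cos = torch.ones(1, S, 1, D)   # (1, S, 1, D): one row per position,
    sin = torch.zeros(1, S, 1, D)  # broadcast across batch and heads

    q_rope = q * cos + rotate_half(q) * sin   # broadcasts to (B, S, N, D)
    assert q_rope.shape == (B, S, N, D)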
@@ -2508,22 +2494,6 @@ def execute_model(
         aclgraph_runtime_mode, batch_descriptor = \
             self.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)

-        # initialize rope
-        cos_sin_cache = self.model.model.layers[
-            self.model.model.
-            start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(
-                0, positions)
-        last_dim = cos_sin_cache.size()[-1]
-        cos, sin = cos_sin_cache.reshape(-1, 2,
-                                         last_dim // 2).repeat(1, 1,
-                                                               2).chunk(2,
-                                                                        dim=-2)
-        # BSNH
-        self.cos[:, :maybe_padded_num_tokens] = cos.view(
-            1, -1, 1, last_dim).contiguous()
-        self.sin[:, :maybe_padded_num_tokens] = sin.view(
-            1, -1, 1, last_dim).contiguous()
-
         # Run forward pass
         with ProfileExecuteDuration().capture_async("forward"):
             with set_ascend_forward_context(
@@ -2540,11 +2510,7 @@ def execute_model(
                     total_num_scheduled_tokens,
                     prefetch_stream=self.prefetch_stream,
                     model_instance=self.model,
-                    weight_prefetch_method=self.weight_prefetch_method,
-                    cos=self.cos[:, :maybe_padded_num_tokens]
-                    if self.cos is not None else None,
-                    sin=self.sin[:, :maybe_padded_num_tokens]
-                    if self.sin is not None else None):
+                    weight_prefetch_method=self.weight_prefetch_method):
                 self.maybe_setup_kv_connector(scheduler_output)

                 hidden_states = self._generate_process_reqs_hidden_states(
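The per-step rope precomputation deleted above (and again in the drafter dummy-run path below) splits vLLM's fused cos_sin_cache, whose rows hold the cos half followed by the sin half of each position's rotary table (an assumption carried over from the removed code). A tiny worked example of the reshape/repeat/chunk sequence:

    import torch

    T, D = 4, 8                        # 4 positions, rotary dim 8 (toy sizes)
    cache = torch.arange(T * D, dtype=torch.float32).reshape(T, D)

    last_dim = cache.size()[-1]        # D
    cos, sin = cache.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
    # cos, sin: (T, 1, D); each D/2-wide half is tiled to the full head width
    cos = cos.view(1, -1, 1, last_dim).contiguous()   # (1, T, 1, D): the BSNH layout
    sin = sin.view(1, -1, 1, last_dim).contiguous()

    # Row 0 of cache is [0..7]: cos half [0,1,2,3] tiled, sin half [4,5,6,7] tiled.
    assert cos[0, 0, 0].tolist() == [0., 1., 2., 3., 0., 1., 2., 3.]
    assert sin[0, 0, 0].tolist() == [4., 5., 6., 7., 4., 5., 6., 7.]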
@@ -3252,20 +3218,6 @@ def dummy_drafter_compute_logits(hidden_states):
                 return self.drafter.model.compute_logits(
                     hidden_states[dummy_indices])

-            # initialize rope
-            cos_sin_cache = self.model.model.layers[
-                self.model.model.
-                start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(
-                    0, positions)
-            last_dim = cos_sin_cache.size()[-1]
-            cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat(
-                1, 1, 2).chunk(2, dim=-2)
-            # BSNH
-            self.cos[:, :num_tokens] = cos.view(1, -1, 1,
-                                                last_dim).contiguous()
-            self.sin[:, :num_tokens] = sin.view(1, -1, 1,
-                                                last_dim).contiguous()
-
             with set_ascend_forward_context(
                     attn_metadata,
                     self.vllm_config,
@@ -3280,11 +3232,7 @@ def dummy_drafter_compute_logits(hidden_states):
                     batch_descriptor=batch_descriptor,
                     prefetch_stream=self.prefetch_stream,
                     model_instance=self.model,
-                    weight_prefetch_method=self.weight_prefetch_method,
-                    cos=self.cos[:, :num_tokens]
-                    if self.cos is not None else None,
-                    sin=self.sin[:, :num_tokens]
-                    if self.sin is not None else None):
+                    weight_prefetch_method=self.weight_prefetch_method):
                 hidden_states = self._generate_dummy_run_hidden_states(
                     input_ids, positions, num_tokens_padded,
                     intermediate_tensors, inputs_embeds)
