
Commit c57cdca

tiny fix

Signed-off-by: wxsIcey <[email protected]>
1 parent a41460e commit c57cdca

File tree: 2 files changed (+2, -60 lines)

vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 6 deletions
@@ -72,8 +72,6 @@ def set_ascend_forward_context(
         prefetch_stream: torch.npu.Stream = None,
         model_instance: torch.nn.Module = None,
         weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
         is_mtp_model=False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -162,10 +160,6 @@ def set_ascend_forward_context(
         forward_context.weight_prefetch_method = weight_prefetch_method
         forward_context.is_mtp_model = is_mtp_model

-        # initialize rope
-        forward_context.cos = cos
-        forward_context.sin = sin
-
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
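For orientation: `set_ascend_forward_context` follows the usual task-local forward-context pattern, where state is stashed on entry, is readable from anywhere inside the forward pass, and is restored on exit. A minimal generic sketch of that pattern (contextvars-based; the real function wraps vLLM's `set_forward_context` and carries many more fields):

# A generic sketch of the forward-context pattern, NOT the vllm-ascend
# implementation: the real set_ascend_forward_context builds on vLLM's
# set_forward_context and stores far more state.
from contextlib import contextmanager
from contextvars import ContextVar
from types import SimpleNamespace

_forward_ctx: ContextVar = ContextVar("forward_ctx", default=None)

@contextmanager
def forward_context_sketch(**fields):
    token = _forward_ctx.set(SimpleNamespace(**fields))  # install the context
    try:
        yield _forward_ctx.get()
    finally:
        _forward_ctx.reset(token)                        # always restore on exit

with forward_context_sketch(attn_metadata=None, is_mtp_model=False) as ctx:
    # Everything running under the `with` block sees the same context object,
    # which is why dropping the cos/sin fields here removes them for every consumer.
    print(ctx.is_mtp_model)  # False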

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 54 deletions
@@ -418,20 +418,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                    rope_dim,
                                    dtype=self.dtype,
                                    device=self.device)
-        # For GQA models.
-        elif not self.vllm_config.model_config.use_mla:
-            self.cos = torch.ones(1,
-                                  self.max_num_tokens,
-                                  1,
-                                  128,
-                                  dtype=self.dtype,
-                                  device=self.device)
-            self.sin = torch.zeros(1,
-                                   self.max_num_tokens,
-                                   1,
-                                   128,
-                                   dtype=self.dtype,
-                                   device=self.device)
         else:
             self.cos = None
             self.sin = None
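The deleted `elif` covered non-MLA (GQA) models and preallocated cos/sin in BSNH layout (batch, sequence, heads, head size) with a hard-coded head size of 128; after this commit those models fall through to the `else` branch and get `self.cos = self.sin = None`. A toy reproduction of the deleted allocation, just to make the layout concrete (sizes are stand-ins):

import torch

# BSNH buffers as in the deleted branch: B=1, S=max_num_tokens, N=1, H=128.
max_num_tokens = 32  # stand-in; the runner uses its real token cap
cos = torch.ones(1, max_num_tokens, 1, 128, dtype=torch.float16)
sin = torch.zeros(1, max_num_tokens, 1, 128, dtype=torch.float16)

# Downstream code only ever touched the live prefix, e.g. cos[:, :num_tokens],
# which is a view (no copy) into the preallocated buffer.
num_tokens = 7
assert cos[:, :num_tokens].shape == (1, num_tokens, 1, 128)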
@@ -2530,22 +2516,6 @@ def execute_model(
             aclgraph_runtime_mode, batch_descriptor = \
                 self.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)

-        # initialize rope
-        cos_sin_cache = self.model.model.layers[
-            self.model.model.
-            start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(
-                0, positions)
-        last_dim = cos_sin_cache.size()[-1]
-        cos, sin = cos_sin_cache.reshape(-1, 2,
-                                         last_dim // 2).repeat(1, 1,
-                                                               2).chunk(2,
-                                                                        dim=-2)
-        # BSNH
-        self.cos[:, :maybe_padded_num_tokens] = cos.view(
-            1, -1, 1, last_dim).contiguous()
-        self.sin[:, :maybe_padded_num_tokens] = sin.view(
-            1, -1, 1, last_dim).contiguous()
-
         # Run forward pass
         with ProfileExecuteDuration().capture_async("forward"):
             with set_ascend_forward_context(
@@ -2562,11 +2532,7 @@ def execute_model(
                     total_num_scheduled_tokens,
                     prefetch_stream=self.prefetch_stream,
                     model_instance=self.model,
-                    weight_prefetch_method=self.weight_prefetch_method,
-                    cos=self.cos[:, :maybe_padded_num_tokens]
-                    if self.cos is not None else None,
-                    sin=self.sin[:, :maybe_padded_num_tokens]
-                    if self.sin is not None else None):
+                    weight_prefetch_method=self.weight_prefetch_method):
                 self.maybe_setup_kv_connector(scheduler_output)

                 hidden_states = self._generate_process_reqs_hidden_states(
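The deleted `# initialize rope` block above unpacks vLLM's `cos_sin_cache`, whose rows store the cos half followed by the sin half of the rotary table, into full-width BSNH cos/sin tensors. A self-contained sketch of that exact reshape/repeat/chunk sequence, with a toy cache standing in for `rotary_emb.cos_sin_cache`:

import torch

# Build a toy cache in vLLM's layout: row p = [cos(p*f) halves, sin(p*f) halves].
max_pos, rot_dim = 16, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)   # (16, 4)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (16, 8)

positions = torch.tensor([0, 1, 2, 5])                 # per-token positions
selected = cos_sin_cache.index_select(0, positions)    # (T, rot_dim)

# The deleted logic: split each row into (cos_half, sin_half), tile each half
# to full width, then separate cos from sin and view as BSNH.
last_dim = selected.size(-1)
cos, sin = selected.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, last_dim).contiguous()        # (1, T, 1, rot_dim)
sin = sin.view(1, -1, 1, last_dim).contiguous()
assert cos.shape == sin.shape == (1, 4, 1, rot_dim)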
@@ -3274,20 +3240,6 @@ def dummy_drafter_compute_logits(hidden_states):
             return self.drafter.model.compute_logits(
                 hidden_states[dummy_indices])

-        # initialize rope
-        cos_sin_cache = self.model.model.layers[
-            self.model.model.
-            start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(
-                0, positions)
-        last_dim = cos_sin_cache.size()[-1]
-        cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat(
-            1, 1, 2).chunk(2, dim=-2)
-        # BSNH
-        self.cos[:, :num_tokens] = cos.view(1, -1, 1,
-                                            last_dim).contiguous()
-        self.sin[:, :num_tokens] = sin.view(1, -1, 1,
-                                            last_dim).contiguous()
-
         with set_ascend_forward_context(
                 attn_metadata,
                 self.vllm_config,
@@ -3302,11 +3254,7 @@ def dummy_drafter_compute_logits(hidden_states):
                 batch_descriptor=batch_descriptor,
                 prefetch_stream=self.prefetch_stream,
                 model_instance=self.model,
-                weight_prefetch_method=self.weight_prefetch_method,
-                cos=self.cos[:, :num_tokens]
-                if self.cos is not None else None,
-                sin=self.sin[:, :num_tokens]
-                if self.sin is not None else None):
+                weight_prefetch_method=self.weight_prefetch_method):
             hidden_states = self._generate_dummy_run_hidden_states(
                 with_prefill, input_ids, positions, attn_metadata,
                 num_tokens_padded, intermediate_tensors, inputs_embeds)
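Tables in that half-duplicated layout are consumed with the standard GPT-NeoX-style rotate-half formula. A generic sketch of the application step (the textbook formula, not necessarily the kernel vllm-ascend actually dispatches to):

import torch

def apply_rope_bsnh(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q: (B, S, N, H); cos/sin: (1, S, 1, H) with each half duplicated,
    # as produced by the reshape/repeat/chunk sequence sketched earlier.
    h = q.size(-1)
    q1, q2 = q[..., :h // 2], q[..., h // 2:]
    rotated = torch.cat((-q2, q1), dim=-1)  # "rotate half"
    return q * cos + rotated * sin

q = torch.randn(1, 4, 2, 8)        # toy (B, S, N, H)
cos = torch.ones(1, 4, 1, 8)       # cos=1, sin=0 is the identity rotation
sin = torch.zeros(1, 4, 1, 8)
assert torch.allclose(apply_rope_bsnh(q, cos, sin), q)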
