
Commit 3581946

[Bugfix] fix eagle proposer (#4971)
### What this PR does / why we need it?

After #4764, many tensors created by `make_buffer` must be accessed through their renamed views, e.g. `input_ids` -> `input_ids.gpu`.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: realliujiaxu <[email protected]>
1 parent: 45889a6
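
For context, #4764 switched these runner buffers from bare tensors with `_cpu`/`_np` suffixed twins to a single wrapper exposing `.cpu`, `.gpu`, and `.np` views, which is why every access site in the diff below changes its spelling. The sketch below only illustrates that pattern; the class name, constructor, and helper method are hypothetical stand-ins, not vLLM's actual `make_buffer` implementation.

```python
# Illustrative sketch (assumed pattern, not vLLM's real API): a paired
# CPU/device buffer with a zero-copy numpy view of the CPU side.
import torch


class CpuGpuBufferSketch:
    def __init__(self, *size: int, dtype: torch.dtype, device: str):
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu")
        self.np = self.cpu.numpy()   # numpy view sharing the CPU tensor's memory
        self.gpu = torch.zeros(*size, dtype=dtype, device=device)

    def copy_to_device(self, n: int) -> None:
        # Stage values on the CPU side (often through .np), then push the
        # used prefix to the device tensor.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)


# With such a wrapper, the old suffixed names map onto attribute accesses:
#   runner.input_ids[:n]      ->  runner.input_ids.gpu[:n]
#   runner.input_ids_cpu[:n]  ->  runner.input_ids.cpu[:n]
#   runner.positions_np[:n]   ->  runner.positions.np[:n]
```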

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 28 additions & 27 deletions
@@ -169,7 +169,7 @@ def generate_token_ids(self,
         eagle_attn_metadata = attn_metadata[self.attn_layer_name]
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
-            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
+            target_token_ids = self.runner.input_ids.gpu[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
             if self.name == SpecDcodeType.EAGLE3:
                 target_hidden_states = torch.cat(
@@ -192,7 +192,7 @@ def generate_token_ids(self,
             )
             cu_num_tokens, token_indices =\
                 self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
-            target_token_ids = self.runner.input_ids[token_indices]
+            target_token_ids = self.runner.input_ids.gpu[token_indices]
             target_positions = positions[token_indices]
             if self.name == SpecDcodeType.EAGLE3:
                 target_hidden_states = torch.cat(
@@ -245,7 +245,7 @@ def _get_eagle_atten_dict(
                                 num_scheduled_tokens)

         # Get positions.
-        positions_np = self.runner.positions_np[:total_num_scheduled_tokens]
+        positions_np = self.runner.positions.np[:total_num_scheduled_tokens]
         np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
                out=positions_np)
@@ -270,7 +270,7 @@ def _get_eagle_atten_dict(
             self.runner.input_batch.token_ids_cpu_tensor.flatten(),
             0,
             torch.from_numpy(token_indices),
-            out=self.runner.input_ids_cpu[:total_num_scheduled_tokens])
+            out=self.runner.input_ids.cpu[:total_num_scheduled_tokens])

         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
@@ -299,60 +299,61 @@ def _get_eagle_atten_dict(
             np.add(
                 block_numbers * block_size,
                 block_offsets,
-                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
+                out=block_table.slot_mapping.np[:total_num_scheduled_tokens])

         # Prepare the attention metadata.
-        self.runner.query_start_loc_np[0] = 0
-        self.runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.runner.query_start_loc.np[0] = 0
+        self.runner.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens

-        self.runner.seq_lens_np[:num_reqs] = (
+        self.runner.seq_lens.np[:num_reqs] = (
             self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)

         # Copy the tensors to the NPU.
-        self.runner.input_ids[:total_num_scheduled_tokens].copy_(
-            self.runner.input_ids_cpu[:total_num_scheduled_tokens],
+        self.runner.input_ids.gpu[:total_num_scheduled_tokens].copy_(
+            self.runner.input_ids.cpu[:total_num_scheduled_tokens],
             non_blocking=True)
         if self.runner.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            self.runner.mrope_positions[:, :total_num_scheduled_tokens].copy_(
-                self.runner.
-                mrope_positions_cpu[:, :total_num_scheduled_tokens],
-                non_blocking=True)
+            self.runner.mrope_positions.gpu[:, :total_num_scheduled_tokens] \
+                .copy_(
+                    self.runner.
+                    mrope_positions.cpu[:, :total_num_scheduled_tokens],
+                    non_blocking=True)
         else:
             # Common case (1D positions)
-            self.runner.positions[:total_num_scheduled_tokens].copy_(
-                self.runner.positions_cpu[:total_num_scheduled_tokens],
+            self.runner.positions.gpu[:total_num_scheduled_tokens].copy_(
+                self.runner.positions.cpu[:total_num_scheduled_tokens],
                 non_blocking=True)

-        self.runner.query_start_loc[:num_reqs + 1].copy_(
-            self.runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.runner.seq_lens[:num_reqs].copy_(
-            self.runner.seq_lens_cpu[:num_reqs], non_blocking=True)
+        self.runner.query_start_loc.gpu[:num_reqs + 1].copy_(
+            self.runner.query_start_loc.cpu[:num_reqs + 1], non_blocking=True)
+        self.runner.seq_lens.gpu[:num_reqs].copy_(
+            self.runner.seq_lens.cpu[:num_reqs], non_blocking=True)

         # Fill unused with -1. Needed for reshape_and_cache
-        self.runner.seq_lens[num_reqs:].fill_(0)
-        self.runner.query_start_loc[num_reqs + 1:].fill_(-1)
+        self.runner.seq_lens.gpu[num_reqs:].fill_(0)
+        self.runner.query_start_loc.gpu[num_reqs + 1:].fill_(-1)

         attn_metadata = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.runner.kv_cache_config.kv_cache_groups):
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
-                query_start_loc_cpu=self.runner.query_start_loc_cpu[:num_reqs +
+                query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1],
+                query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs +
                                                                     1],
-                seq_lens_cpu=self.runner.seq_lens_cpu,
+                seq_lens_cpu=self.runner.seq_lens.cpu,
                 num_reqs=num_reqs,
                 max_query_len=max_num_scheduled_tokens,
                 num_actual_tokens=total_num_scheduled_tokens,
                 actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                 block_table_tensor=self.runner.input_batch.block_table[0].
                 get_device_tensor(),
                 slot_mapping=self.runner.input_batch.block_table[0].
-                slot_mapping,
-                positions=self.runner.positions,
+                slot_mapping.gpu,
+                positions=self.runner.positions.gpu,
                 attn_mask=self.runner.attn_mask,
                 spec_attn_mask=self.runner.spec_attn_mask,
                 attn_state=self.runner.attn_state,
