
Commit 21b5fe3

bugfix && variable regularization
Signed-off-by: Apocalypse990923-qshi <[email protected]>
1 parent 38d2030 commit 21b5fe3

3 files changed: 9 additions, 35 deletions


vllm_ascend/attention/mla_v1.py

Lines changed: 0 additions & 25 deletions
@@ -906,31 +906,6 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)

-    def reorder_by_req(self, num_tokens_per_req_per_rank):
-        num_tokens_per_rank_per_req = [
-            list(i) for i in zip(*num_tokens_per_req_per_rank)
-        ]  # transpose to [rank, req]
-        num_ranks = len(num_tokens_per_rank_per_req)
-        num_reqs = len(num_tokens_per_rank_per_req[0])
-        assert all(len(x) == num_reqs for x in num_tokens_per_rank_per_req)
-
-        # calc each rank's start offset
-        offsets = []
-        offset = 0
-        for rank_tokens in num_tokens_per_rank_per_req:
-            offsets.append(offset)
-            offset += sum(rank_tokens)
-
-        reordered = []
-        for req_idx in range(num_reqs):
-            for rank_idx in range(num_ranks):
-                start = offsets[rank_idx] + sum(
-                    num_tokens_per_rank_per_req[rank_idx][:req_idx])
-                end = start + num_tokens_per_rank_per_req[rank_idx][req_idx]
-                reordered.extend(range(start, end))
-
-        return torch.tensor(reordered, dtype=torch.int32)
-
     def extract_req_dcp_by_chunk_cp(self, lst, chunk_idx, fill_value=0):
         num_reqs = len(lst)
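For reference, the deleted reorder_by_req is self-contained and easy to check in isolation. A minimal sketch of it outside the class (input values invented for illustration): given per-request, per-rank token counts laid out rank-contiguously, it returns flat indices regrouped request by request.

import torch

def reorder_by_req(num_tokens_per_req_per_rank):
    # Transpose [req][rank] -> [rank][req].
    per_rank = [list(col) for col in zip(*num_tokens_per_req_per_rank)]
    num_ranks, num_reqs = len(per_rank), len(per_rank[0])

    # Each rank's tokens sit in one contiguous block; record block starts.
    offsets, offset = [], 0
    for rank_tokens in per_rank:
        offsets.append(offset)
        offset += sum(rank_tokens)

    # Emit indices request by request, walking every rank's slice.
    reordered = []
    for req_idx in range(num_reqs):
        for rank_idx in range(num_ranks):
            start = offsets[rank_idx] + sum(per_rank[rank_idx][:req_idx])
            end = start + per_rank[rank_idx][req_idx]
            reordered.extend(range(start, end))
    return torch.tensor(reordered, dtype=torch.int32)

# Two requests over two ranks: req0 has 2 tokens on rank0, 1 on rank1.
print(reorder_by_req([[2, 1], [1, 2]]))
# tensor([0, 1, 3, 2, 4, 5], dtype=torch.int32)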

vllm_ascend/worker/block_table.py

Lines changed: 7 additions & 7 deletions
@@ -356,7 +356,7 @@ def get_split_computed_tokens(
         # else:
         #     assert len(request_ids) == num_requests
         assert request_ids is not None and len(request_ids) == num_requests
-        num_computed_tokens_of_cp_dcp = [[[0] * self.dcp_world_size
+        num_computed_tokens_of_pcp_dcp_for_chunk = [[[0] * self.dcp_world_size
                                           for _ in range(self.pcp_world_size)]
                                          for _ in range(num_requests)]
         total_ranks = self.pcp_world_size * self.dcp_world_size
@@ -382,11 +382,11 @@ def get_split_computed_tokens(
                else:
                    pcp_idx = start_rank // self.dcp_world_size
                    dcp_idx = start_rank % self.dcp_world_size
-                   num_computed_tokens_of_cp_dcp[req_idx][pcp_idx][
+                   num_computed_tokens_of_pcp_dcp_for_chunk[req_idx][pcp_idx][
                        dcp_idx] += consumed_tokens
                    request_start_rank_dict[req_id] = (start_rank,
                                                       tokens_blank)
-           return num_computed_tokens_of_cp_dcp
+           return num_computed_tokens_of_pcp_dcp_for_chunk

            virtual_size = total_ranks * cp_kv_cache_interleave_size
            base = int(total_tokens) // virtual_size
@@ -397,7 +397,7 @@ def get_split_computed_tokens(
            for rank_idx in range(total_ranks):
                pcp_idx = rank_idx // self.dcp_world_size
                dcp_idx = rank_idx % self.dcp_world_size
-               num_computed_tokens_of_cp_dcp[req_idx][pcp_idx][
+               num_computed_tokens_of_pcp_dcp_for_chunk[req_idx][pcp_idx][
                    dcp_idx] = base * cp_kv_cache_interleave_size

            # Distribute remainder tokens starting from start_rank
@@ -406,11 +406,11 @@ def get_split_computed_tokens(
                pcp_idx = rank // self.dcp_world_size
                dcp_idx = rank % self.dcp_world_size
                if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0:  # not last block or divisible
-                   num_computed_tokens_of_cp_dcp[req_idx][pcp_idx][
+                   num_computed_tokens_of_pcp_dcp_for_chunk[req_idx][pcp_idx][
                        dcp_idx] += 1 * cp_kv_cache_interleave_size
                    tokens_blank = 0
                else:  # if last block and undivisible
-                   num_computed_tokens_of_cp_dcp[req_idx][pcp_idx][
+                   num_computed_tokens_of_pcp_dcp_for_chunk[req_idx][pcp_idx][
                        dcp_idx] += remainder % cp_kv_cache_interleave_size
                    tokens_blank = cp_kv_cache_interleave_size - (
                        remainder % cp_kv_cache_interleave_size)
@@ -422,7 +422,7 @@ def get_split_computed_tokens(
            if request_start_rank_dict is not None:
                request_start_rank_dict[req_id] = (start_rank, tokens_blank)

-       return num_computed_tokens_of_cp_dcp
+       return num_computed_tokens_of_pcp_dcp_for_chunk

    def clear(self) -> None:
        for block_table in self.block_tables:
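The rename is mechanical (cp becomes pcp, plus the chunk qualifier); the structure is still indexed [request][pcp_rank][dcp_rank]. The interleaved split that fills it can be sketched standalone from the hunks above. Note this is a reconstruction, not the file's exact code: the remain_blocks computation and the round-robin walk from start_rank are assumptions, and names are simplified.

import math

def split_computed_tokens(total_tokens, pcp_world_size, dcp_world_size,
                          interleave, start_rank=0):
    # Deal tokens to ranks in `interleave`-sized blocks, round-robin.
    total_ranks = pcp_world_size * dcp_world_size
    out = [[0] * dcp_world_size for _ in range(pcp_world_size)]

    # Every rank first receives `base` full interleave blocks.
    virtual_size = total_ranks * interleave
    base = total_tokens // virtual_size
    for rank in range(total_ranks):
        out[rank // dcp_world_size][rank % dcp_world_size] = base * interleave

    # Distribute remainder blocks starting from start_rank;
    # only the last block may be partial.
    remainder = total_tokens % virtual_size
    remain_blocks = math.ceil(remainder / interleave)
    for i in range(remain_blocks):
        rank = (start_rank + i) % total_ranks
        pcp_idx, dcp_idx = rank // dcp_world_size, rank % dcp_world_size
        if i < remain_blocks - 1 or remainder % interleave == 0:
            out[pcp_idx][dcp_idx] += interleave
        else:
            out[pcp_idx][dcp_idx] += remainder % interleave
    return out

print(split_computed_tokens(10, 2, 2, interleave=2))
# [[4, 2], [2, 2]] -- rank 0 absorbs the remainder block; totals sum to 10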

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 3 deletions
@@ -1463,7 +1463,6 @@ def _prepare_inputs(
     ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
                int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor], int]:
-        # self.slot_mapping.fill_(0)
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -1488,13 +1487,13 @@ def _prepare_inputs(
         self.generate_kv_idx(tokens, scheduler_output)
         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
         tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
             tokens)
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         # update total_num_scheduled_tokens
         total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
-        self.input_batch.block_table.commit_slot_mapping(
-            total_num_scheduled_tokens)

         total_num_pcp_pads = sum(self.num_pcp_pads)
         max_num_scheduled_tokens = max(tokens)
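Moving commit_slot_mapping above _update_tokens_for_pcp looks like the actual bugfix here: compute_slot_mapping builds entries for every scheduled token, but the old order committed them only after total_num_scheduled_tokens had been overwritten with the smaller per-rank PCP count, truncating the committed region. A toy illustration of that hazard (all names and values invented, not the real APIs):

slot_mapping = list(range(8))      # entries built for all 8 scheduled tokens
committed = []

def commit_slot_mapping(n):
    # Stand-in for the real commit: publish the first n entries.
    committed[:] = slot_mapping[:n]

total_num_scheduled_tokens = 8     # scheduler's full count

# Fixed order: commit before PCP sharding rewrites the total.
commit_slot_mapping(total_num_scheduled_tokens)

# PCP sharding then shrinks the per-rank total (factor invented).
total_num_scheduled_tokens = total_num_scheduled_tokens // 2

assert len(committed) == 8         # the old order committed only 4 entries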
