@@ -356,7 +356,7 @@ def get_split_computed_tokens(
356356 # else:
357357 # assert len(request_ids) == num_requests
358358 assert request_ids is not None and len (request_ids ) == num_requests
359- num_computed_tokens_of_cp_dcp = [[[0 ] * self .dcp_world_size
359+ num_computed_tokens_of_pcp_dcp_for_chunk = [[[0 ] * self .dcp_world_size
360360 for _ in range (self .pcp_world_size )]
361361 for _ in range (num_requests )]
362362 total_ranks = self .pcp_world_size * self .dcp_world_size
@@ -382,11 +382,11 @@ def get_split_computed_tokens(
382382 else :
383383 pcp_idx = start_rank // self .dcp_world_size
384384 dcp_idx = start_rank % self .dcp_world_size
385- num_computed_tokens_of_cp_dcp [req_idx ][pcp_idx ][
385+ num_computed_tokens_of_pcp_dcp_for_chunk [req_idx ][pcp_idx ][
386386 dcp_idx ] += consumed_tokens
387387 request_start_rank_dict [req_id ] = (start_rank ,
388388 tokens_blank )
389- return num_computed_tokens_of_cp_dcp
389+ return num_computed_tokens_of_pcp_dcp_for_chunk
390390
391391 virtual_size = total_ranks * cp_kv_cache_interleave_size
392392 base = int (total_tokens ) // virtual_size
@@ -397,7 +397,7 @@ def get_split_computed_tokens(
397397 for rank_idx in range (total_ranks ):
398398 pcp_idx = rank_idx // self .dcp_world_size
399399 dcp_idx = rank_idx % self .dcp_world_size
400- num_computed_tokens_of_cp_dcp [req_idx ][pcp_idx ][
400+ num_computed_tokens_of_pcp_dcp_for_chunk [req_idx ][pcp_idx ][
401401 dcp_idx ] = base * cp_kv_cache_interleave_size
402402
403403 # Distribute remainder tokens starting from start_rank
@@ -406,11 +406,11 @@ def get_split_computed_tokens(
406406 pcp_idx = rank // self .dcp_world_size
407407 dcp_idx = rank % self .dcp_world_size
408408 if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0 : # not last block or divisible
409- num_computed_tokens_of_cp_dcp [req_idx ][pcp_idx ][
409+ num_computed_tokens_of_pcp_dcp_for_chunk [req_idx ][pcp_idx ][
410410 dcp_idx ] += 1 * cp_kv_cache_interleave_size
411411 tokens_blank = 0
412412 else : # if last block and undivisible
413- num_computed_tokens_of_cp_dcp [req_idx ][pcp_idx ][
413+ num_computed_tokens_of_pcp_dcp_for_chunk [req_idx ][pcp_idx ][
414414 dcp_idx ] += remainder % cp_kv_cache_interleave_size
415415 tokens_blank = cp_kv_cache_interleave_size - (
416416 remainder % cp_kv_cache_interleave_size )
@@ -422,7 +422,7 @@ def get_split_computed_tokens(
422422 if request_start_rank_dict is not None :
423423 request_start_rank_dict [req_id ] = (start_rank , tokens_blank )
424424
425- return num_computed_tokens_of_cp_dcp
425+ return num_computed_tokens_of_pcp_dcp_for_chunk
426426
427427 def clear (self ) -> None :
428428 for block_table in self .block_tables :
0 commit comments