Commit 9a6585a

niushengxiao committed
feat: add alloc_paged_token_indices function in req_manager
1 parent 6d227b3 commit 9a6585a

19 files changed: 317 additions & 251 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 3 deletions
@@ -687,8 +687,8 @@ def _check_max_len_infer(self):
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(
-                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+            mem_indexes = self.req_manager.alloc_paged_token_indices(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len
             ).cuda()
             total_token_num = self.batch_max_tokens
             b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
@@ -759,12 +759,14 @@ def _autotune_warmup(self):
                 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
             )
             b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = input_len
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
             total_token_num = input_len
             b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+            mem_indexes = self.req_manager.alloc_paged_token_indices(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len
+            ).cuda()
             model_input = ModelInput(
                 batch_size=1,
                 total_token_num=total_token_num,
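Note: both prefill call sites now hand the per-request bookkeeping tensors (b_req_idx, b_seq_len, b_ready_cache_len) to the request manager instead of the memory manager. The new req_manager.alloc_paged_token_indices is defined in lightllm/common/req_manager.py, which is among the 19 changed files but not shown in this excerpt. The snippet below is only a minimal sketch inferred from the call sites above; the class name ReqManagerSketch and all internals (including the absence of page-granular rounding) are assumptions, not the committed code.

import torch

class ReqManagerSketch:
    def __init__(self, mem_manager, req_to_token_indexs: torch.Tensor):
        self.mem_manager = mem_manager                    # owns the raw KV-cache slot pool
        self.req_to_token_indexs = req_to_token_indexs    # [max_request_num, max_seq_len] slot table

    def alloc_paged_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None):
        # Take need_size free KV-cache slots from the memory manager ...
        mem_indexes = self.mem_manager.alloc(need_size)
        # ... and record them in the per-request table so that (request, position) can later be
        # mapped back to a slot. In prefill, b_ready_cache_len marks how many positions are already
        # backed by prefix-cached tokens; in decode no ready length is passed and each request
        # receives exactly one new slot at position b_seq_len - 1.
        start = 0
        for i in range(len(b_req_idx)):
            ready = int(b_ready_cache_len[i]) if b_ready_cache_len is not None else int(b_seq_len[i]) - 1
            count = int(b_seq_len[i]) - ready
            self.req_to_token_indexs[b_req_idx[i], ready : ready + count] = mem_indexes[start : start + count]
            start += count
        return mem_indexes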

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 2 deletions
@@ -202,7 +202,7 @@ def warmup(self, model):
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
             b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()
+            mem_indexes = model.req_manager.alloc_paged_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda()

             model_input = ModelInput(
                 batch_size=batch_size,
@@ -258,7 +258,7 @@ def warmup_overlap(self, model):
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
             b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()
+            mem_indexes = model.req_manager.alloc_paged_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda()

             micro_batch = ModelInput(
                 is_prefill=False,

lightllm/common/deepseek2_page_size_variable_mem_manager.py renamed to lightllm/common/deepseek2_paged_mem_manager.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from .deepseek2_mem_manager import Deepseek2MemoryManager
-from .page_size_variable_mem_manager import PageSizeVariableMemoryManager
+from .paged_mem_manager import PagedMemoryManager
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_page_size

@@ -13,7 +13,7 @@ def cdiv(a, b):
 logger = init_logger(__name__)


-class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager):
+class Deepseek2PagedMemoryManager(PagedMemoryManager, Deepseek2MemoryManager):
     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
         super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
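The rename only swaps the base class; the constructor signature visible above is unchanged. A hypothetical construction of the renamed manager, for orientation only (every argument value below is illustrative and not taken from the commit):

import torch
from lightllm.common.deepseek2_paged_mem_manager import Deepseek2PagedMemoryManager

# Illustrative values: DeepSeek2-style MLA caches a single compressed KV head per layer.
mem_manager = Deepseek2PagedMemoryManager(
    size=16384,           # number of KV-cache token slots to reserve
    dtype=torch.bfloat16,
    head_num=1,
    head_dim=576,
    layer_num=60,
    mem_fraction=0.9,
)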

lightllm/common/mem_manager.py

Lines changed: 1 addition & 17 deletions
@@ -52,8 +52,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
-        # MemoryManager also needs a backup of this reference for internal use
-        self.req_to_token_indexs = None

     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -245,9 +243,7 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None

-    def alloc(
-        self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False
-    ) -> torch.Tensor:
+    def alloc(self, need_size) -> torch.Tensor:
         if need_size > self.mark_end - self.mark_start:
             logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
             assert False, "error alloc state"
@@ -261,9 +257,6 @@ def alloc(
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         return ans

-    def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
-        self.req_to_token_indexs[req_idx, start:end] = values
-
     def free(self, free_index: Union[torch.Tensor, List[int]]):
         """_summary_

@@ -342,17 +335,8 @@ def __init__(self) -> None:
             SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
             for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
         ]
-        self.shared_tp_info_pages = [
-            SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}")
-            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
-        ]

     def get_unrefed_token_num(self, dp_rank_in_node: int):
         if self.is_multinode_tp:
             return self.shared_tp_infos[0].get_value()
         return self.shared_tp_infos[dp_rank_in_node].get_value()
-
-    def get_unrefed_page_num(self, dp_rank_in_node: int):
-        if self.is_multinode_tp:
-            return self.shared_tp_info_pages[0].get_value()
-        return self.shared_tp_info_pages[dp_rank_in_node].get_value()
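With the request-table bookkeeping removed, MemoryManager.alloc shrinks to a pure slot allocator. The middle of the method is not part of this excerpt; the sketch below fills it in under the assumption that free slot indices come from a precomputed buffer (called mem_state here) bounded by mark_start and mark_end. Only the lines also visible in the hunks above are certain.

    def alloc(self, need_size) -> torch.Tensor:
        if need_size > self.mark_end - self.mark_start:
            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
            assert False, "error alloc state"
        # Assumed middle section: hand out the next need_size free slot indices.
        ans = self.mem_state[self.mark_start : self.mark_start + need_size]
        self.mark_start += need_size
        self.can_use_mem_size -= need_size
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
        return ans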

lightllm/common/mem_utils.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
-from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager
+from lightllm.common.paged_mem_manager import PagedMemoryManager
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -30,7 +30,7 @@ def select_mem_manager_class(mode):
         memory_manager_class = ExportCalibrationMemoryManager
         logger.info("Using mode export fp8kv calibration")
     elif "page_size_variable" in mode:
-        memory_manager_class = PageSizeVariableMemoryManager
+        memory_manager_class = PagedMemoryManager
         logger.info("Page size will be variable")
     else:
         memory_manager_class = MemoryManager
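Only the class name changes in the selector; the "page_size_variable" mode string still picks the paged manager. A hypothetical use of the selector (the mode value is illustrative):

from lightllm.common.mem_utils import select_mem_manager_class
from lightllm.common.paged_mem_manager import PagedMemoryManager

# A mode list containing "page_size_variable" now resolves to the renamed class.
mem_manager_class = select_mem_manager_class(mode=["page_size_variable"])
assert mem_manager_class is PagedMemoryManager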

lightllm/common/page_size_variable_mem_manager.py

Lines changed: 0 additions & 184 deletions
This file was deleted.
