4 changes: 2 additions & 2 deletions docs/CN/source/models/add_new_model.md
@@ -454,8 +454,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl):
def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
o_tensor = torch.empty_like(q)
token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :],
infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :],
infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0: self.tp_k_head_num_, :],
infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
layer_weight.tp_alibi,
infer_state.b_loc,
4 changes: 2 additions & 2 deletions docs/EN/source/models/add_new_model.md
@@ -454,8 +454,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl):
def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
o_tensor = torch.empty_like(q)
token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :],
infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :],
infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0: self.tp_k_head_num_, :],
infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
layer_weight.tp_alibi,
infer_state.b_loc,
@@ -42,7 +42,7 @@ def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight):
return

def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
destindex_copy_kv(buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_))
return

def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
@@ -10,3 +10,4 @@
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight_tp import FusedMoeWeightTP
from .fused_moe_weight_ep import FusedMoeWeightEP
from .parameter_weight import ParameterWeight, TpParameterWeight
@@ -0,0 +1,44 @@
import torch
from typing import Dict
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id


class ParameterWeight(BaseWeightTpl):
def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None):
super().__init__()
self.weight_name = weight_name
self.bias_name = bias_name
self.data_type_ = data_type
self.weight = None
self.bias = None

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_name in weights:
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

def verify_load(self):
load_ok = True
# Verify weight: it must not be None.
load_ok = load_ok and self.weight is not None
# Verify bias: if bias_name is set, the bias must not be None.
if self.bias_name is not None:
load_ok = load_ok and self.bias is not None
return load_ok


class TpParameterWeight(ParameterWeight):
def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None):
super().__init__(weight_name, data_type, bias_name)
self.split_n_embed = split_n_embed

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
start = self.split_n_embed * self.tp_rank_
end = self.split_n_embed * (self.tp_rank_ + 1)

if self.weight_name in weights:
self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
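To illustrate the tensor-parallel slicing that `TpParameterWeight.load_hf_weights` performs, here is a self-contained toy sketch (the weight name and sizes are made up, and the dtype cast plus the `.cuda(get_current_device_id())` move done by the real class are omitted):

```python
import torch

# Toy reproduction of the TpParameterWeight sharding rule: each tensor-parallel
# rank keeps rows [split_n_embed * tp_rank, split_n_embed * (tp_rank + 1)) of
# the checkpoint tensor named weight_name.
weights = {"model.layers.0.some_param": torch.arange(8, dtype=torch.float32)}

split_n_embed, tp_rank = 4, 1            # hypothetical: n_embed = 8 split across 2 ranks
start = split_n_embed * tp_rank
end = split_n_embed * (tp_rank + 1)

shard = weights["model.layers.0.some_param"][start:end]
print(shard)                             # tensor([4., 5., 6., 7.]) on rank 1
```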
@@ -36,8 +36,17 @@ def load_hf_weights(self, weights):
"""
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, MultiMMWeightTpl):
if isinstance(attr, TransformerLayerWeight):
attr.load_hf_weights(weights)
elif isinstance(attr, MultiMMWeightTpl):
with self.lock:
attr.load_hf_weights(weights)
elif isinstance(attr, BaseWeight):
attr.load_hf_weights(weights)

def verify_load(self):
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, TransformerLayerWeight):
attr.verify_load()
super().verify_load()
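The order of the `isinstance` checks above matters whenever the weight classes share a base type, as they do in the stand-in hierarchy below: nested `TransformerLayerWeight` members are matched before the generic `BaseWeight` branch, and multi-matrix weights are loaded while holding the shared lock. A self-contained toy of the same dir()-scan dispatch (stand-in classes only; the real lightllm hierarchy is not shown in this diff):

```python
import threading

# Stand-in classes only; the real lightllm types may be related differently.
class BaseWeight:
    def load_hf_weights(self, weights):
        print("loaded", type(self).__name__)

class MultiMMWeightTpl(BaseWeight):
    pass

class TransformerLayerWeight(BaseWeight):
    pass

class Container:
    def __init__(self):
        self.lock = threading.Lock()
        self.nested = TransformerLayerWeight()  # recursed into by the first branch
        self.qkv = MultiMMWeightTpl()           # loaded while holding self.lock
        self.scale = BaseWeight()               # caught by the final branch

    def load_hf_weights(self, weights):
        # Same pattern as the hunk above: scan attributes and dispatch by type,
        # most-derived classes first so they are not swallowed by BaseWeight.
        for attr_name in dir(self):
            attr = getattr(self, attr_name, None)
            if isinstance(attr, TransformerLayerWeight):
                attr.load_hf_weights(weights)
            elif isinstance(attr, MultiMMWeightTpl):
                with self.lock:
                    attr.load_hf_weights(weights)
            elif isinstance(attr, BaseWeight):
                attr.load_hf_weights(weights)

Container().load_hf_weights({})
```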
7 changes: 5 additions & 2 deletions lightllm/common/mem_manager.py
@@ -89,11 +89,14 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
# as a member variable, serving a role and purpose similar to HOLD_REQUEST_ID in req_manager.
self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")

def get_kv_buffer(self, layer_num: int):
return self.kv_buffer[layer_num]

def alloc_kv_move_buffer(self, max_req_total_len):
"""
Special interface used by the pd (prefill/decode) disaggregation mode
"""
if isinstance(self, MemoryManager) and type(self) != MemoryManager:
if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
raise NotImplementedError("subclass need reimpl this method")
self.kv_move_buffer = torch.empty(
(1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
@@ -103,7 +106,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
if isinstance(self, MemoryManager) and type(self) != MemoryManager:
if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
raise NotImplementedError("subclass need reimpl this method")

num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
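For reference, a self-contained sketch of the access pattern the new helper standardizes: call sites elsewhere in this diff fetch one layer's buffer via `get_kv_buffer()` and then slice the K heads and V heads out of the combined `[token, 2 * head_num, head_dim]` layout (the sizes below are made up):

```python
import torch

class ToyMemoryManager:
    """Stand-in mirroring the buffer layout above: [layer_num, size + 1, 2 * head_num, head_dim]."""
    def __init__(self, layer_num=2, size=16, head_num=4, head_dim=8, dtype=torch.float16):
        self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype)

    def get_kv_buffer(self, layer_num: int):
        return self.kv_buffer[layer_num]

mm = ToyMemoryManager()
kv = mm.get_kv_buffer(0)   # one layer: [size + 1, 2 * head_num, head_dim]
k = kv[:, 0:4, :]          # first head_num slots hold K (here head_num = 4)
v = kv[:, 4:8, :]          # remaining head_num slots hold V
```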
2 changes: 1 addition & 1 deletion lightllm/common/triton_utils/autotuner.py
@@ -62,7 +62,7 @@ def autotune(
as needed before invocation.
"""

def decorator(fn):
def decorator(fn: Callable) -> Callable:
return Autotuner(
fn=fn,
kernel_name=kernel_name,
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -8,6 +8,7 @@
from lightllm.models.qwen2.model import Qwen2TpPartModel
from lightllm.models.qwen3.model import Qwen3TpPartModel
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
from lightllm.models.qwen3next.model import Qwen3NextTpPartModel
from lightllm.models.chatglm2.model import ChatGlm2TpPartModel
from lightllm.models.internlm.model import InternlmTpPartModel
from lightllm.models.stablelm.model import StablelmTpPartModel
4 changes: 2 additions & 2 deletions lightllm/models/bloom/layer_infer/transformer_layer_infer.py
@@ -54,7 +54,7 @@ def _context_attention_kernel(
self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None
) -> torch.Tensor:
o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
@@ -74,7 +74,7 @@ def _token_attention_kernel(
self, q, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None
) -> torch.Tensor:
o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
token_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
32 changes: 16 additions & 16 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -306,11 +306,11 @@ def _decompress_kv(
):
if not skip_sample:
if is_fp8:
kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(torch.bfloat16)
k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype)
else:
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
kv_scale = None
k_scale = None

@@ -549,7 +549,7 @@ def _context_attention_kernel_origin(
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
context_attention_fwd(
q_nope,
q_rope,
@@ -577,8 +577,8 @@ def _context_attention_kernel_origin_fp8(
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(torch.bfloat16)
context_attention_fwd_fp8(
q_nope,
q_rope,
@@ -601,7 +601,7 @@ def _token_gqa_decode_attention_flashattention(
):
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
k_descale, v_descale = None, None
@@ -631,7 +631,7 @@ def _token_gqa_decode_attention_flashinfer(
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)

infer_state.decode_wrapper.run(
@@ -649,7 +649,7 @@ def _token_gqa_decode_attention_flashdecoding(
):
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)
out = gqa_token_decode_attention_flash_decoding(
q_nope,
q_rope,
@@ -671,8 +671,8 @@ def _token_gqa_decode_attention_flashdecoding_fp8(
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(torch.bfloat16)
return gqa_token_decode_attention_flash_decoding_fp8(
q_nope,
q_rope,
@@ -693,8 +693,8 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
buffer[:, :, : self.kv_lora_rank],
buffer[:, :, self.kv_lora_rank :],
mem_index,
mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :],
mem_manager.get_kv_buffer(self.layer_num_)[:, :, : self.kv_lora_rank],
mem_manager.get_kv_buffer(self.layer_num_)[:, :, self.kv_lora_rank :],
)
return

@@ -703,9 +703,9 @@ def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
buffer[:, :, : self.kv_lora_rank],
buffer[:, :, self.kv_lora_rank :],
mem_index,
mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn),
mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn),
mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype),
mem_manager.get_kv_buffer(self.layer_num_)[:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn),
mem_manager.get_kv_buffer(self.layer_num_)[:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn),
mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(buffer.dtype),
)
return

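A small self-contained illustration (made-up shapes; needs a PyTorch build with float8 support) of the fp8 buffer layout these slices assume: each `[token, head]` row stores `head_dim` one-byte fp8 payload values followed by 2 trailing bytes that are reinterpreted as a single bfloat16 scale.

```python
import torch

tokens, head_dim = 10, 64
# One layer's buffer, stored as raw bytes: fp8 payload + 2 scale bytes per row.
layer_kv = torch.zeros((tokens, 1, head_dim + 2), dtype=torch.uint8)

kv_fp8 = layer_kv[:, :, :-2].view(torch.float8_e4m3fn)  # quantized K/V payload, shape (10, 1, 64)
kv_scale = layer_kv[:, :, -2:].view(torch.bfloat16)     # per-row scale, shape (10, 1, 1)
print(kv_fp8.shape, kv_scale.shape)
```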
@@ -76,10 +76,10 @@ def _context_sliding_attention_flashattention(
else:
window_size = (-1, -1)

cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :].reshape(
-1, 1, self.tp_k_head_num_, self.head_dim_
)
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
@@ -112,10 +112,10 @@ def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttention
else:
window_size = (-1, -1)

cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :].reshape(
-1, 1, self.tp_k_head_num_, self.head_dim_
)
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)