4 changes: 2 additions & 2 deletions docker/Dockerfile
@@ -39,8 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir

RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

# TODO: offline compile
# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl && \
pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl

RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel

4 changes: 2 additions & 2 deletions docker/Dockerfile.deepep
@@ -39,8 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir

RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

# TODO: offline compile
# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl && \
pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl

RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms
RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev
8 changes: 4 additions & 4 deletions docs/CN/source/getting_started/benchmark.rst
@@ -89,15 +89,15 @@ ShareGPT 数据集测试 (benchmark_sharegpt.py)
python test/benchmark/service/benchmark_sharegpt.py \
--dataset /path/to/sharegpt_dataset.json \
--tokenizer /path/to/tokenizer \
--num_prompts 1000 \
--request_rate 10.0
--num-prompts 1000 \
--request-rate 10.0

**主要参数:**

- ``--dataset``: ShareGPT 格式数据集路径
- ``--tokenizer``: 分词器路径
- ``--num_prompts``: 测试提示数量
- ``--request_rate``: 请求速率 (requests/s)
- ``--num-prompts``: 测试提示数量
- ``--request-rate``: 请求速率 (requests/s)


Prompt Cache 测试
8 changes: 4 additions & 4 deletions docs/EN/source/getting_started/benchmark.rst
@@ -88,15 +88,15 @@ Performance testing using ShareGPT real conversation data.
python test/benchmark/service/benchmark_sharegpt.py \
--dataset /path/to/sharegpt_dataset.json \
--tokenizer /path/to/tokenizer \
--num_prompts 1000 \
--request_rate 10.0
--num-prompts 1000 \
--request-rate 10.0

**Main Parameters:**

- ``--dataset``: ShareGPT format dataset path
- ``--tokenizer``: Tokenizer path
- ``--num_prompts``: Number of test prompts
- ``--request_rate``: Request rate (requests/s)
- ``--num-prompts``: Number of test prompts
- ``--request-rate``: Request rate (requests/s)

Prompt Cache Testing
~~~~~~~~~~~~~~~~~~~
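For context on what ``--request-rate`` means in practice: serving benchmarks of this kind usually pace requests with exponential inter-arrival gaps (a Poisson process) averaging the requested rate. The sketch below only illustrates that convention; it is not taken from ``benchmark_sharegpt.py``, and ``fake_send`` is a hypothetical stand-in for the real HTTP call.

```python
import asyncio
import random

async def fake_send(prompt: str) -> None:
    # Hypothetical stand-in for the real HTTP request to the LightLLM server.
    await asyncio.sleep(0)

async def send_requests(prompts, request_rate: float) -> None:
    # Fire each request without waiting for its completion, then sleep an
    # exponentially distributed gap so arrivals average `request_rate` req/s.
    tasks = []
    for prompt in prompts:
        tasks.append(asyncio.create_task(fake_send(prompt)))
        if request_rate != float("inf"):
            await asyncio.sleep(random.expovariate(request_rate))
    await asyncio.gather(*tasks)

asyncio.run(send_requests([f"prompt-{i}" for i in range(5)], request_rate=10.0))
```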
103 changes: 103 additions & 0 deletions lightllm/common/flash_attn.py
@@ -0,0 +1,103 @@
import torch
from typing import List, Optional, Tuple, Union
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def get_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x


try:
import flash_attn_3._C # Registers operators with PyTorch

flash_attn_3_mtp = torch.ops.flash_attn_3

def flash_attn_with_kvcache_mtp(
q,
k,
v,
k_new: Optional[torch.Tensor] = None,
v_new: Optional[torch.Tensor] = None,
q_v: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
page_table: Optional[torch.Tensor] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
is_causal=False,
window_size=(-1, -1),
softcap=0.0, # 0.0 means deactivated
is_rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0,
pack_gqa=None,
sm_margin=0,
mtp_step=0,
):
assert k.stride(-1) == 1, "k must have contiguous last dimension"
assert v.stride(-1) == 1, "v must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (q_v.shape[-1] if q_v is not None else 0)) ** (-0.5)
seqused_k = get_contiguous(seqused_k)

q, k, k_new, v_new = [get_contiguous(x) for x in (q, k, k_new, v_new)]
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
cu_seqlens_q, cu_seqlens_k_new = [get_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
page_table = get_contiguous(page_table)
out, softmax_lse, *rest = flash_attn_3_mtp.fwd(
q,
k,
v,
k_new,
v_new,
q_v,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
seqused_k,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale,
k_descale,
v_descale,
softmax_scale,
is_causal,
window_size[0],
window_size[1],
0,
softcap,
is_rotary_interleaved,
scheduler_metadata,
num_splits,
pack_gqa,
sm_margin,
mtp_step,
)
return out

except:
flash_attn_3_mtp = None
flash_attn_with_kvcache_mtp = None
logger.warning("flash_attn_3._C is not available, please install flash-attention-3 package.")
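The ``try``/``except`` import guard above degrades to ``None`` exports when ``flash_attn_3`` is missing, so call sites have to check for ``None`` before wiring the kernel in (as ``api_start.py`` does later in this PR). A minimal sketch of that caller-side pattern, where ``pick_decode_kernel`` is a hypothetical helper and not part of the PR:

```python
from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp

def pick_decode_kernel(enable_fa3_mtp: bool):
    # If the optional FA3-MTP kernel failed to import, the module exports None.
    if enable_fa3_mtp:
        if flash_attn_with_kvcache_mtp is None:
            raise RuntimeError(
                "flash_attn_3 is not installed; install the fa3_mtp wheel or drop --enable_fa3_mtp"
            )
        return flash_attn_with_kvcache_mtp
    return None  # caller falls back to the non-MTP attention path
```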
38 changes: 37 additions & 1 deletion lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -30,6 +30,7 @@
from lightllm.utils.dist_utils import get_global_world_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp

logger = init_logger(__name__)

@@ -70,6 +71,8 @@ def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
self.num_heads = network_config["num_attention_heads"]
self.num_kv_heads = network_config["num_key_value_heads"]
self.mtp_step = get_env_start_args().mtp_step
self.mtp_size = self.mtp_step + 1
return

def _bind_func(self):
@@ -95,7 +98,11 @@ def _bind_attention(self):
)
else:
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
if get_env_start_args().enable_fa3:
if get_env_start_args().enable_fa3_mtp:

Review comment (medium): The function get_env_start_args() is called multiple times within this if/elif chain (lines 103, 107, 111). To improve performance and readability, consider calling it once before this conditional block and storing the result in a local variable.

self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self
)
elif get_env_start_args().enable_fa3:
self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
)
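A hedged sketch of the reviewer's suggestion, shown as a fragment of ``_bind_attention`` rather than a change that is part of this PR: fetch the start args once and branch on the local variable.

```python
start_args = get_env_start_args()  # single lookup instead of one per branch
if start_args.enable_fa3_mtp:
    self._token_attention_kernel = partial(
        Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self
    )
elif start_args.enable_fa3:
    self._token_attention_kernel = partial(
        Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
    )
```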
@@ -559,6 +566,35 @@ def _context_attention_kernel_origin_fp8(
)
return o_tensor

def _token_gqa_decode_attention_mtp(
self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
k_descale, v_descale = None, None
o_tensor = flash_attn_with_kvcache_mtp(
q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim),
k=k_rope,
v=kv_nope,
q_v=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size],
seqused_k=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(),
cu_seqlens_q=infer_state.cu_seqlens_q,
cu_seqlens_k_new=infer_state.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=self.softmax_scale,
is_causal=True,
window_size=(-1, -1),
softcap=0.0,
k_descale=k_descale,
v_descale=v_descale,
mtp_step=self.mtp_step,
)
return o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank)

def _token_gqa_decode_attention_flashattention(
self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
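The ``[self.mtp_size - 1 :: self.mtp_size]`` slicing in ``_token_gqa_decode_attention_mtp`` keeps one row per request out of ``mtp_size`` decode positions. A standalone illustration of that arithmetic, assuming the per-position row layout implied by the code above (each request's main token followed by its draft positions):

```python
import torch

mtp_step = 1
mtp_size = mtp_step + 1  # rows per request: main token + 1 draft token
num_reqs = 3

# One entry per decode position; the last position of each request carries the
# largest sequence length for that request.
b_seq_len = torch.tensor([17, 18, 33, 34, 9, 10])
page_table = torch.arange(num_reqs * mtp_size).unsqueeze(1)  # toy page table

print(b_seq_len[mtp_size - 1 :: mtp_size])   # tensor([18, 34, 10]) -> one length per request
print(page_table[mtp_size - 1 :: mtp_size])  # rows 1, 3, 5 -> last position of each request
```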
2 changes: 1 addition & 1 deletion lightllm/models/deepseek2/model.py
@@ -69,7 +69,7 @@ def __init__(self, kvargs):
return

def _init_inferstate_cls(self):
if get_env_start_args().enable_fa3:
if get_env_start_args().enable_fa3 or get_env_start_args().enable_fa3_mtp:

Review comment (medium): The function get_env_start_args() is called twice on this line. To improve performance and readability, consider calling it once before the if statement and storing the result in a local variable.

self.infer_state_class = Deepseek2FlashAttentionStateInfo
elif self.enable_flashinfer:
self.infer_state_class = Deepseek2FlashInferStateInfo
5 changes: 5 additions & 0 deletions lightllm/server/api_cli.py
@@ -507,6 +507,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
but ensure that the model is compatible with the specified step count.
currently, deepseekv3 model only support 1 step""",
)
parser.add_argument(
"--enable_fa3_mtp",
action="store_true",
help="""inference backend will use the fa3_mtp kernel for decode with MTP mode""",
)
parser.add_argument(
"--kv_quant_calibration_config_path",
type=str,
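The new flag is a plain ``store_true`` switch that defaults to ``False`` and is only meaningful together with ``--mtp_mode`` (see the check added in ``api_start.py`` below). A self-contained sketch using a fresh parser rather than the full LightLLM CLI, with ``deepseekv3`` as an assumed mode value:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mtp_mode", type=str, default=None)
parser.add_argument(
    "--enable_fa3_mtp",
    action="store_true",
    help="inference backend will use the fa3_mtp kernel for decode with MTP mode",
)

args = parser.parse_args(["--mtp_mode", "deepseekv3", "--enable_fa3_mtp"])
assert args.enable_fa3_mtp and args.mtp_mode is not None  # mirrors the api_start.py check
```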
7 changes: 7 additions & 0 deletions lightllm/server/api_start.py
@@ -15,6 +15,7 @@
from .router.manager import start_router_process
from lightllm.utils.process_check import is_process_active
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp
from lightllm.utils.shm_size_check import check_recommended_shm_size

logger = init_logger(__name__)
@@ -139,6 +140,12 @@ def normal_or_p_d_start(args):
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0

if args.enable_fa3_mtp:
assert args.mtp_mode is not None, "enable_fa3_mtp must set mtp_mode"
assert (
flash_attn_with_kvcache_mtp is not None
), "flash_attn_with_kvcache_mtp is None, please check if you have installed the fa3_mtp kernel"

# 检查GPU数量是否足够
if args.visual_gpu_ids is None:
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -100,6 +100,7 @@ class StartArgs:
mtp_mode: Optional[str] = field(default=None)
mtp_draft_model_dir: Optional[str] = field(default=None)
mtp_step: int = field(default=0)
enable_fa3_mtp: bool = field(default=False)
kv_quant_calibration_config_path: Optional[str] = field(default=None)
nixl_pd_kv_page_num: int = field(default=16)
nixl_pd_kv_page_size: int = field(default=1024)