diff --git a/docs/CN/source/models/add_new_model.md b/docs/CN/source/models/add_new_model.md index 49b47ffa2..275f8896c 100755 --- a/docs/CN/source/models/add_new_model.md +++ b/docs/CN/source/models/add_new_model.md @@ -454,8 +454,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl): def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: o_tensor = torch.empty_like(q) token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0: self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), layer_weight.tp_alibi, infer_state.b_loc, diff --git a/docs/EN/source/models/add_new_model.md b/docs/EN/source/models/add_new_model.md index 6127dffaf..b9e92cdf3 100755 --- a/docs/EN/source/models/add_new_model.md +++ b/docs/EN/source/models/add_new_model.md @@ -454,8 +454,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl): def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: o_tensor = torch.empty_like(q) token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0: self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), layer_weight.tp_alibi, infer_state.b_loc, diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 7567bc644..4cb86c09c 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -42,7 +42,7 @@ def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight): return def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + destindex_copy_kv(buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_)) return def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 9db819dfd..15d5aff5c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -10,3 +10,4 @@ from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight from .fused_moe_weight_tp import FusedMoeWeightTP from .fused_moe_weight_ep import FusedMoeWeightEP +from .parameter_weight import ParameterWeight, TpParameterWeight diff --git 
a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py new file mode 100644 index 000000000..65adcd469 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -0,0 +1,44 @@ +import torch +from typing import Dict +from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id + + +class ParameterWeight(BaseWeightTpl): + def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__() + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self.weight = None + self.bias = None + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_name in weights: + self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights: + self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + + def verify_load(self): + load_ok = True + # Verify weight. The weight must not be None. + load_ok = load_ok and self.weight is not None + # Verify bias. If bias_name is set, the bias must not be None. + if self.bias_name is not None: + load_ok = load_ok and self.bias is not None + return load_ok + + +class TpParameterWeight(ParameterWeight): + def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None): + super().__init__(weight_name, data_type, bias_name) + self.split_n_embed = split_n_embed + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) + + if self.weight_name in weights: + self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights: + self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 48167a067..0c6a7b9be 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,8 +36,17 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MultiMMWeightTpl): + if isinstance(attr, TransformerLayerWeight): + attr.load_hf_weights(weights) + elif isinstance(attr, MultiMMWeightTpl): with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + + def verify_load(self): + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, TransformerLayerWeight): + attr.verify_load() + super().verify_load() diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 57ae9838b..718573272 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -89,11 +89,14 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + def get_kv_buffer(self, layer_num: int): + return self.kv_buffer[layer_num] + def alloc_kv_move_buffer(self, max_req_total_len): """ pd 分离模式使用的特殊接口 """ - if isinstance(self, 
MemoryManager) and type(self) != MemoryManager: + if isinstance(self, MemoryManager) and type(self) is not MemoryManager: raise NotImplementedError("subclass need reimpl this method") self.kv_move_buffer = torch.empty( (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" @@ -103,7 +106,7 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: - if isinstance(self, MemoryManager) and type(self) != MemoryManager: + if isinstance(self, MemoryManager) and type(self) is not MemoryManager: raise NotImplementedError("subclass need reimpl this method") num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir) diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..c62a2572f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -62,7 +62,7 @@ def autotune( as needed before invocation. """ - def decorator(fn): + def decorator(fn: Callable) -> Callable: return Autotuner( fn=fn, kernel_name=kernel_name, diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 96329eabe..03b4fa352 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -8,6 +8,7 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel from lightllm.models.chatglm2.model import ChatGlm2TpPartModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 8299697f3..d93f144b0 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -54,7 +54,7 @@ def _context_attention_kernel( self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) context_attention_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), kv[:, 0 : self.tp_k_head_num_, :], @@ -74,7 +74,7 @@ def _token_attention_kernel( self, q, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) token_attention_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), kv[:, 0 : self.tp_k_head_num_, :], diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 30d37d1df..1758a9d1a 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -306,11 +306,11 @@ def _decompress_kv( ): if not skip_sample: if is_fp8: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, 
-2:].view(torch.bfloat16) + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, :-2].view(torch.float8_e4m3fn) + kv_scale = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(torch.bfloat16) k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype) else: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) kv_scale = None k_scale = None @@ -549,7 +549,7 @@ def _context_attention_kernel_origin( q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) context_attention_fwd( q_nope, q_rope, @@ -577,8 +577,8 @@ def _context_attention_kernel_origin_fp8( q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, :-2].view(torch.float8_e4m3fn) + kv_scale = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(torch.bfloat16) context_attention_fwd_fp8( q_nope, q_rope, @@ -601,7 +601,7 @@ def _token_gqa_decode_attention_flashattention( ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) k_descale, v_descale = None, None @@ -631,7 +631,7 @@ def _token_gqa_decode_attention_flashinfer( q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) infer_state.decode_wrapper.run( @@ -649,7 +649,7 @@ def _token_gqa_decode_attention_flashdecoding( ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) out = gqa_token_decode_attention_flash_decoding( q_nope, q_rope, @@ -671,8 +671,8 @@ def _token_gqa_decode_attention_flashdecoding_fp8( q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) + kv = 
infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, :-2].view(torch.float8_e4m3fn) + kv_scale = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(torch.bfloat16) return gqa_token_decode_attention_flash_decoding_fp8( q_nope, q_rope, @@ -693,8 +693,8 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): buffer[:, :, : self.kv_lora_rank], buffer[:, :, self.kv_lora_rank :], mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank], - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :], + mem_manager.get_kv_buffer(self.layer_num_)[:, :, : self.kv_lora_rank], + mem_manager.get_kv_buffer(self.layer_num_)[:, :, self.kv_lora_rank :], ) return @@ -703,9 +703,9 @@ def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager): buffer[:, :, : self.kv_lora_rank], buffer[:, :, self.kv_lora_rank :], mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype), + mem_manager.get_kv_buffer(self.layer_num_)[:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn), + mem_manager.get_kv_buffer(self.layer_num_)[:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn), + mem_manager.get_kv_buffer(self.layer_num_)[:, :, -2:].view(buffer.dtype), ) return diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 1246af090..73bc2c25d 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -76,10 +76,10 @@ def _context_sliding_attention_flashattention( else: window_size = (-1, -1) - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) @@ -112,10 +112,10 @@ def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttention else: window_size = (-1, -1) - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index bb38c45bb..bbb96ef67 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -238,7 +238,7 @@ def _context_attention_flashinfer_kernel_fp8( self, q, kv, infer_state: LlamaFlashInferStateInfo, 
layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) kv = kv.unsqueeze(1) k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) @@ -258,7 +258,7 @@ def _context_attention_flashinfer_kernel( self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) kv = kv.unsqueeze(1) infer_state.prefill_wrapper.run( q.view(q.shape[0], -1, self.head_dim_), @@ -271,7 +271,7 @@ def _context_attention_kernel( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) context_attention_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), kv[:, 0 : self.tp_k_head_num_, :], @@ -291,7 +291,7 @@ def _context_attention_kernel_ppl_int8kv( ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out batch_size = infer_state.b_seq_len.shape[0] - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) kv_scale = infer_state.mem_manager.scale_buffer[self.layer_num_] max_seq_len = infer_state.max_seq_len kv_dequant = self.alloc_tensor( @@ -319,10 +319,10 @@ def _context_attention_kernel_ppl_int8kv( return o_tensor def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) @@ -359,13 +359,13 @@ def _context_attention_flashattention_fp8( infer_state.token_batch_ids, ) cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) + (infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, : self.tp_k_head_num_, :]) .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) .view(torch.float8_e4m3fn) ) cache_v = ( ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] ) @@ -480,17 +480,17 @@ def _tpsp_ffn( # return ffn2_out def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + destindex_copy_kv(buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_)) return def _copy_kv_to_mem_cache_with_calibration(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + destindex_copy_kv(buffer, mem_index, 
mem_manager.get_kv_buffer(self.layer_num_)) mem_manager.update_calibration_data(buffer, self.layer_num_) return def _copy_kv_to_mem_cache_int8kv(self, buffer, mem_index, mem_manager): destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] + buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_), mem_manager.scale_buffer[self.layer_num_] ) return @@ -500,7 +500,7 @@ def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): buffer, mem_index, scales[self.layer_num_] if scales is not None else None, - mem_manager.kv_buffer[self.layer_num_].view(torch.float8_e4m3fn), + mem_manager.get_kv_buffer(self.layer_num_).view(torch.float8_e4m3fn), ) return @@ -508,7 +508,7 @@ def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager): from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_quantize_kv destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] + buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_), mem_manager.scale_buffer[self.layer_num_] ) return @@ -516,7 +516,7 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): from lightllm.models.llama.triton_kernel.ppl_int4kv_copy_kv import destindex_copy_int4kv destindex_copy_int4kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] + buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_), mem_manager.scale_buffer[self.layer_num_] ) return @@ -525,7 +525,7 @@ def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInfer calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_).unsqueeze(1) k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) offline_scales = infer_state.mem_manager.scales_list @@ -545,7 +545,7 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_).unsqueeze(1) infer_state.decode_wrapper.run( q.view(calcu_shape1), (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), @@ -562,7 +562,7 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, la token_att_fwd( q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :], att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -578,7 +578,7 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, la token_softmax_reducev_fwd( att_m_tensor, - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ], o_tensor.view(calcu_shape1), @@ -598,8 +598,8 @@ def _token_decode_gqa_attention_normal(self, q, infer_state: LlamaInferStateInfo o_tensor = 
self.alloc_tensor(q.shape, q.dtype) if out is None else out gqa_decode_attention_fwd( q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ], o_tensor.view(calcu_shape1), @@ -616,7 +616,7 @@ def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, la att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), q.dtype) token_att_fwd_int8k( q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :], infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], att_m_tensor, infer_state.req_manager.req_to_token_indexs, @@ -635,7 +635,7 @@ def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, la o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out token_att_fwd2_int8v( prob, - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ], infer_state.mem_manager.scale_buffer[self.layer_num_][ @@ -654,8 +654,8 @@ def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, la def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] return token_decode_attention_flash_decoding( @@ -673,8 +673,8 @@ def _token_decode_attention_gqa_flashdecoding(self, q, infer_state: LlamaInferSt # 对 gqa 模型进行推理优化的代码 from ..triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] return gqa_token_decode_attention_flash_decoding( @@ -698,9 +698,9 @@ def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo light_ops.group8_int8kv_decode_attention( o_tensor.view(calcu_shape1), q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :], infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ], 
infer_state.mem_manager.scale_buffer[self.layer_num_][ @@ -726,8 +726,8 @@ def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, o_tensor.view(calcu_shape1), 1.0 / (self.head_dim_ ** 0.5), q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ], infer_state.req_manager.req_to_token_indexs, @@ -743,8 +743,8 @@ def _token_decode_attention_ppl_fp16_flashdecoding( ): from lightllm.models.llama.triton_kernel.ppl_fp16_flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] return token_decode_attention_flash_decoding( @@ -763,9 +763,9 @@ def _token_decode_attention_ppl_int8kv_flashdecoding( ): from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ @@ -789,9 +789,9 @@ def _token_decode_attention_ppl_int4kv_flashdecoding( ): from lightllm.models.llama.triton_kernel.ppl_int4kv_flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ @@ -817,8 +817,8 @@ def _token_decode_attention_gqa_flashdecoding_vsm( gqa_token_decode_attention_flash_decoding_vsm, ) - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_) @@ -832,10 +832,10 @@ def _token_decode_attention_gqa_flashdecoding_vsm( ) def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, 
out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) @@ -865,13 +865,13 @@ def _token_decode_attention_flashattention_fp8( self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None ): cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) + (infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, : self.tp_k_head_num_, :]) .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) .view(torch.float8_e4m3fn) ) cache_v = ( ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ + infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] ) diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index ce27e3ee5..258b6ad85 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -41,14 +41,14 @@ def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Ph return q, cache_kv def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + destindex_copy_kv(buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_)) return def _context_attention_kernel( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + kv = infer_state.mem_manager.get_kv_buffer(self.layer_num_) context_attention_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), kv[:, 0 : self.tp_k_head_num_, :], @@ -66,8 +66,8 @@ def _context_attention_kernel( def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + cache_k = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.get_kv_buffer(self.layer_num_)[ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ] return token_decode_attention_flash_decoding( diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index e3d8de461..2add77f38 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -17,7 +17,7 @@ def __init__(self, kvargs): def _init_config(self): super()._init_config() - if self.config["sliding_window"] is None: + if "sliding_window" not in self.config or self.config["sliding_window"] is None: self.config["sliding_window"] = self.max_total_token_num # rename key return diff --git a/lightllm/models/qwen3next/__init__.py 
b/lightllm/models/qwen3next/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3next/layer_infer/gdn_layer_infer.py b/lightllm/models/qwen3next/layer_infer/gdn_layer_infer.py new file mode 100644 index 000000000..d698975dc --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/gdn_layer_infer.py @@ -0,0 +1,184 @@ +from typing_extensions import Self +import torch +import torch.distributed as dist +from einops import rearrange + +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops.chunk import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import fused_recurrent_gated_delta_rule +from lightllm.distributed import all_reduce +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer +from lightllm.models.qwen3next.layer_weights.gdn_layer_weight import Qwen3NextGatedDeltaNetWeight + + +class Qwen3NextGatedDeltaNetInfer(BaseLayerInfer): + def __init__(self, layer_idx, network_config): + super().__init__() + self.network_config_ = network_config + self.layer_idx_ = layer_idx + self.hidden_size = self.network_config_["hidden_size"] + self.num_v_heads = self.network_config_["linear_num_value_heads"] + self.num_k_heads = self.network_config_["linear_num_key_heads"] + self.head_k_dim = self.network_config_["linear_key_head_dim"] + self.head_v_dim = self.network_config_["linear_value_head_dim"] + self.eps_ = self.network_config_["rms_norm_eps"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = self.network_config_["linear_conv_kernel_dim"] + self.activation = self.network_config_["hidden_act"] + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + def _fix_query_key_value_ba_ordering(self, mixed_qkvzba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. 
+ """ + mixed_qkvz, mixed_ba = torch.split(mixed_qkvzba, [self.tp_qkvz_dim, self.tp_ba_dim], dim=-1) + + mixed_qkvz = mixed_qkvz.view( + -1, + self.tp_num_k_heads, + self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads_per_k_head, + ) + mixed_ba = mixed_ba.view(-1, self.tp_num_k_heads, 2 * self.num_v_heads_per_k_head) + + qkvz_split_list = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads_per_k_head * self.head_v_dim), + (self.num_v_heads_per_k_head * self.head_v_dim), + ] + (query, key, value, z) = torch.split(mixed_qkvz, qkvz_split_list, dim=2) + (b, a) = torch.split(mixed_ba, [self.num_v_heads_per_k_head, self.num_v_heads_per_k_head], dim=2) + + query = query.reshape(-1, self.tp_num_k_heads * self.head_k_dim) + key = key.reshape(-1, self.tp_num_k_heads * self.head_k_dim) + value = value.reshape(-1, self.tp_num_v_heads * self.head_v_dim) + z = z.reshape(-1, self.tp_num_v_heads, self.head_v_dim) + b = b.reshape(-1, self.tp_num_v_heads) + a = a.reshape(-1, self.tp_num_v_heads) + + return query, key, value, z, b, a + + def _rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + query, key = map(lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key)) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query, key, value + + def forward( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetWeight, + ): + assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager) + input = input.view(-1, self.hidden_size) + + conv_states, ssm_states = infer_state.mem_manager.get_mamba_state_buffer(self.layer_idx_) + cache_indices = infer_state.b_req_idx + + mixed_qkvzba = layer_weight.in_proj.mm(input) + q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba) + mixed_qkv = torch.cat([q, k, v], dim=-1) + + if infer_state.is_prefill: + mixed_qkv = mixed_qkv.transpose(0, 1) + out_tensor = self.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device) + causal_conv1d_fn( + mixed_qkv, + layer_weight.conv1d.weight.transpose(0, 1), + layer_weight.conv1d.bias, + conv_states.transpose(1, 2), + infer_state.b1_cu_q_seq_len, + out=out_tensor, + cache_indices=cache_indices, + activation=self.activation, + ) + mixed_qkv = out_tensor.transpose(0, 1) + else: + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states.transpose(1, 2), + layer_weight.conv1d.weight.transpose(0, 1), + layer_weight.conv1d.bias, + self.activation, + conv_state_indices=cache_indices, + validate_data=True, + ) + + # Rearrange mixed_qkv to query, key, value + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + + # Compute beta and g + beta = b.sigmoid() + g = fused_gdn_gating(layer_weight.A_log.weight, a, layer_weight.dt_bias.weight) + g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) + + if infer_state.is_prefill: + initial_state = ssm_states[cache_indices].contiguous() + (core_attn_out, last_recurrent_state,) = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=infer_state.b1_cu_q_seq_len, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Update SSM state with final state + ssm_states[cache_indices, ...] 
= last_recurrent_state.to(ssm_states.dtype) + else: + batch_size = input.shape[0] + cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device) + (core_attn_out, last_recurrent_state,) = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=ssm_states, + inplace_final_state=True, + cu_seqlens=cu_seqlens, + ssm_state_indices=cache_indices, + use_qk_l2norm_in_kernel=True, + ) + + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.norm.weight, + layer_weight.norm.bias, + self.eps_, + z, + out=norm_out, + ) + core_attn_out = norm_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + + output = layer_weight.out_proj.mm(core_attn_out) + if self.tp_world_size_ > 1: + all_reduce(output, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) + return output diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..73ea97457 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py @@ -0,0 +1,16 @@ +import os +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np + +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextPostLayerInfer(LlamaPostLayerInfer): + def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..e2746cd8e --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -0,0 +1,174 @@ +import torch +import torch.nn.functional as F +import torch.distributed as dist +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from functools import partial +from lightllm.utils.log_utils import init_logger +from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from typing import Tuple +from typing_extensions import override +from einops import rearrange +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops.chunk import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import fused_recurrent_gated_delta_rule +from 
lightllm.distributed import all_reduce +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward +from lightllm.models.qwen3next.layer_infer.gdn_layer_infer import Qwen3NextGatedDeltaNetInfer + +logger = init_logger(__name__) + + +class Qwen3NextTransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + self.is_gdn = (layer_num + 1) % network_config["full_attention_interval"] != 0 + self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) + + if self.is_gdn: + self.gdn_infer = Qwen3NextGatedDeltaNetInfer(layer_num, network_config) + return + + @override + def _bind_norm(self): + pass + + def _ffn_with_shared_expert( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out + moe_out = self._ffn(input, infer_state, layer_weight) + return shared_expert_out + moe_out + + @override + def _att_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) + return out + + @override + def _ffn_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) + return out + + @override + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) + cache_kv = layer_weight.kv_proj.mm( + input.view(-1, self.embed_dim_), + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + cache_kv[:, : self.tp_k_head_num_, :] = gemma_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), + layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv + + @override + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + o_tensor = layer_weight.o_proj.mm(input) + return o_tensor + + def _context_full_attn( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input)) + q, cache_kv = self._get_qkv(input, infer_state, 
layer_weight) + input = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + q = None + o = o * gate + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def context_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + if self.is_gdn: + o = self.gdn_infer.forward(input1, infer_state, layer_weight.gdn_layer_weight) + else: + o = self._context_full_attn(input1, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight): + gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input)) + q, cache_kv = self._get_qkv(input, infer_state, layer_weight) + input = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = o * gate + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + if self.is_gdn: + o = self.gdn_infer.forward(input1, infer_state, layer_weight.gdn_layer_weight) + else: + o = self._token_full_attn(input1, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings diff --git a/lightllm/models/qwen3next/layer_weights/gdn_layer_weight.py b/lightllm/models/qwen3next/layer_weights/gdn_layer_weight.py new file mode 100644 index 000000000..a9df9dc89 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/gdn_layer_weight.py @@ -0,0 +1,91 @@ +from typing_extensions import override +import torch + +from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + MultiROWMMWeight, + COLMMWeight, + NormWeight, + ROWBMMWeight, + TpParameterWeight, +) + + +class Qwen3NextGatedDeltaNetWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def load_hf_weights(self, weights): + if self.conv1d.weight_name in weights: + weights[self.conv1d.weight_name] = 
self._parse_conv1d(weights[self.conv1d.weight_name].squeeze(1)) + if self.conv1d.bias_name in weights: + weights[self.conv1d.bias_name] = self._parse_conv1d(weights[self.conv1d.bias_name]) + super().load_hf_weights(weights) + + @override + def _parse_config(self): + self.num_v_heads = self.network_config_["linear_num_value_heads"] + self.num_k_heads = self.network_config_["linear_num_key_heads"] + self.k_head_dim = self.network_config_["linear_key_head_dim"] + self.v_head_dim = self.network_config_["linear_value_head_dim"] + + @override + def _init_weight(self): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + self.conv1d = ROWMMWeight( + weight_name=f"{prefix}.conv1d.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="conv1d_weight", + ) + + self.in_proj = MultiROWMMWeight( + weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="in_proj_weight", + ) + + self.out_proj = COLMMWeight( + weight_name=f"{prefix}.out_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="out_proj_weight", + ) + + self.dt_bias = TpParameterWeight( + weight_name=f"{prefix}.dt_bias", + data_type=torch.float32, + split_n_embed=self.num_v_heads // self.tp_world_size_, + ) + + self.A_log = TpParameterWeight( + weight_name=f"{prefix}.A_log", + data_type=torch.float32, + split_n_embed=self.num_v_heads // self.tp_world_size_, + ) + + self.norm = NormWeight( + weight_name=f"{prefix}.norm.weight", + data_type=self.data_type_, + ) + + def _parse_conv1d(self, weight): + qk_dim = self.num_k_heads * self.k_head_dim + v_dim = self.num_v_heads * self.v_head_dim + + q_bias, k_bias, v_bias = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q_bias.chunk(self.tp_world_size_, dim=0) + k_splits = k_bias.chunk(self.tp_world_size_, dim=0) + v_splits = v_bias.chunk(self.tp_world_size_, dim=0) + + new_weight = torch.cat( + [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0 + ) + + return new_weight diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..03e3f894e --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -0,0 +1,113 @@ +import os +import torch +import math +import numpy as np +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.utils.envs_utils import enable_env_vars +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + MultiROWMMWeight, + COLMMWeight, + NormWeight, + FusedMoeWeightTP, + FusedMoeWeightEP, + ROWBMMWeight, +) +from functools import partial +from typing_extensions import override +from lightllm.common.basemodel.layer_weights.meta_weights import TpParameterWeight +from lightllm.models.qwen3next.layer_weights.gdn_layer_weight import Qwen3NextGatedDeltaNetWeight + + +class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + @override + def _parse_config(self): + super()._parse_config() + 
self.full_attention_interval = self.network_config_["full_attention_interval"] + self.is_gdn = (self.layer_num_ + 1) % self.full_attention_interval != 0 + return + + @override + def _init_weight(self): + self._init_moe() + self._init_shared_expert_weight() + + self.att_norm_weight_ = NormWeight( + self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + ) + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) + + if self.is_gdn: + self.gdn_layer_weight = Qwen3NextGatedDeltaNetWeight( + self.layer_num_, self.data_type_, self.network_config_, self.mode, self.quant_cfg + ) + else: + self._init_qkv() + self._init_o() + self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self.o_gate_proj = ROWMMWeight( + weight_name=f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight", + data_type=self.data_type_, + bias_name=f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.bias", + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="o_gate_proj", + ) + return + + @override + def load_hf_weights(self, weights): + if not self.is_gdn: + if self.q_proj.weight_name in weights: + weight = weights[self.q_proj.weight_name] + num_heads = self.tp_q_head_num_ * self.tp_world_size_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self.q_proj.weight_name] = _q_proj + weights[self.o_gate_proj.weight_name] = _gate_proj + if self.q_proj.bias_name in weights: + bias = weights[self.q_proj.bias_name] + num_heads = self.tp_q_head_num_ * self.tp_world_size_ + bias = bias.view(num_heads * 2, self.head_dim) + _q_proj = bias[0::2].reshape(-1) + _gate_proj = bias[1::2].reshape(-1) + weights[self.q_proj.bias_name] = _q_proj + weights[self.o_gate_proj.bias_name] = _gate_proj + + super().load_hf_weights(weights) + + def _init_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + self.shared_expert_gate_up_proj = MultiROWMMWeight( + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate_up_proj", + ) + self.shared_expert_down_proj = COLMMWeight( + weight_name=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_down_proj", + ) + self.shared_expert_gate = ROWMMWeight( + weight_name=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_name=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate", + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py new file mode 100644 index 000000000..d57324902 --- /dev/null +++ b/lightllm/models/qwen3next/mem_manager.py @@ -0,0 +1,98 @@ +import torch +from typing import Dict, List, Protocol, Set, Union, Tuple +from typing_extensions import override +from lightllm.utils.log_utils import init_logger +from lightllm.common.mem_manager import MemoryManager +from lightllm.utils.envs_utils import get_env_start_args + +logger = init_logger(__name__) + + +class MambaStateBufferConfig: + def __init__( + self, + 
conv_state_dtype: torch.dtype, + conv_state_shape: torch.Size, + ssm_state_dtype: torch.dtype, + ssm_state_shape: torch.Size, + ): + + self.conv_state_dtype = conv_state_dtype + self.conv_state_shape = conv_state_shape + self.ssm_state_dtype = ssm_state_dtype + self.ssm_state_shape = ssm_state_shape + + +class Qwen3NextMemoryManager(MemoryManager): + def __init__( + self, + size, + dtype, + num_kv_heads, + head_dim, + layer_num, + full_attention_interval: int, + max_req_num: int, + mamba_state_buffer_config: MambaStateBufferConfig, + always_copy=False, + mem_fraction=0.9, + ): + self.full_attention_interval = full_attention_interval + self.max_req_num = max_req_num + + assert layer_num % full_attention_interval == 0 + self.layer_num_wo_mtp = layer_num + self.full_attn_layer_num = layer_num // full_attention_interval + self.linear_attn_layer_num = layer_num - self.full_attn_layer_num + + self.conv_state_dtype = mamba_state_buffer_config.conv_state_dtype + self.conv_state_shape = mamba_state_buffer_config.conv_state_shape + self.ssm_state_dtype = mamba_state_buffer_config.ssm_state_dtype + self.ssm_state_shape = mamba_state_buffer_config.ssm_state_shape + + self.init_mamba_state_buffer() + + # allocate kv buffer pool secondly. + super().__init__(size, dtype, num_kv_heads, head_dim, self.full_attn_layer_num, always_copy, mem_fraction) + + def init_mamba_state_buffer(self): + self.conv_state_buffers = torch.zeros( + (self.linear_attn_layer_num, self.max_req_num + 1, *self.conv_state_shape), + dtype=self.conv_state_dtype, + device="cuda", + ) + self.ssm_state_buffers = torch.zeros( + (self.linear_attn_layer_num, self.max_req_num + 1, *self.ssm_state_shape), + dtype=self.ssm_state_dtype, + device="cuda", + ) + + @override + def get_kv_buffer(self, layer_index): + assert (layer_index + 1) % self.full_attention_interval == 0, "layer_index is not full attention layer" + return self.kv_buffer[layer_index // self.full_attention_interval] + + def get_mamba_state_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]: + assert (layer_index + 1) % self.full_attention_interval != 0, "layer_index is not linear attention layer" + real_layer_index = layer_index - layer_index // self.full_attention_interval + conv_states = self.conv_state_buffers[real_layer_index] + ssm_states = self.ssm_state_buffers[real_layer_index] + return conv_states, ssm_states + + def free_mamba_state_buffer(self, req_indexes: List[int]): + self.conv_state_buffers[:, req_indexes, ...] = 0 + self.ssm_state_buffers[:, req_indexes, ...] 
= 0 + return + + @override + def _free_buffers(self): + super()._free_buffers() + self.conv_state_buffers = None + self.ssm_state_buffers = None + + @override + def init_buffers(self): + super().init_buffers() + if self.conv_state_buffers is None and self.ssm_state_buffers is None: + self.init_mamba_state_buffer() + return diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py new file mode 100644 index 000000000..85c5bb04c --- /dev/null +++ b/lightllm/models/qwen3next/model.py @@ -0,0 +1,77 @@ +import torch +from typing_extensions import override +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextTransformerLayerInfer +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager, MambaStateBufferConfig + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_next") +class Qwen3NextTpPartModel(Qwen3MOEModel): + # weight class + transformer_weight_class = Qwen3NextTransformerLayerWeight + + # infer class + transformer_layer_infer_class = Qwen3NextTransformerLayerInfer + post_layer_infer_class = Qwen3NextPostLayerInfer + + def __init__(self, kvargs) -> None: + super().__init__(kvargs) + + @override + def autotune_layers(self): + return self.config["full_attention_interval"] + + @override + def _init_config(self): + super()._init_config() + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + @override + def _init_custom(self): + super()._init_custom() + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + + @override + def _init_mem_manager(self): + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + mtp_step = get_env_start_args().mtp_step + self.num_linear_k_heads = self.config["linear_num_key_heads"] + self.num_linear_v_heads = self.config["linear_num_value_heads"] + self.head_linear_k_dim = self.config["linear_key_head_dim"] + self.head_linear_v_dim = self.config["linear_value_head_dim"] + conv_kernel_size = self.config["linear_conv_kernel_dim"] + + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) + + mamba_state_buffer_config = MambaStateBufferConfig( + conv_state_dtype=self.data_type, + conv_state_shape=(conv_kernel_size - 1 + mtp_step, conv_dim // self.tp_world_size_), + ssm_state_dtype=self.data_type, + ssm_state_shape=( + self.num_linear_v_heads // self.tp_world_size_, + self.head_linear_k_dim, + self.head_linear_v_dim, + ), + ) + + self.mem_manager = Qwen3NextMemoryManager( + size=self.max_total_token_num, + dtype=self.data_type, + num_kv_heads=self.num_kv_heads, + head_dim=self.config["head_dim"], + layer_num=self.config["n_layer"], + full_attention_interval=self.config["full_attention_interval"], + max_req_num=self.max_req_num, + mamba_state_buffer_config=mamba_state_buffer_config, + mem_fraction=self.mem_fraction, + ) diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py new 
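# A small self-contained check of the layer-index bookkeeping used by
# Qwen3NextMemoryManager.get_kv_buffer / get_mamba_state_buffer above, with example values
# (48 layers, full_attention_interval=4); the real values come from the model config.
layer_num = 48
full_attention_interval = 4
full_attn_layer_num = layer_num // full_attention_interval   # layers that own a kv_buffer slot
linear_attn_layer_num = layer_num - full_attn_layer_num      # layers that own conv/ssm state slots

kv_slots, mamba_slots = [], []
for layer_index in range(layer_num):
    if (layer_index + 1) % full_attention_interval == 0:
        # full-attention layer: maps into the compact kv_buffer
        kv_slots.append(layer_index // full_attention_interval)
    else:
        # linear-attention (GDN) layer: maps into the mamba state buffers
        mamba_slots.append(layer_index - layer_index // full_attention_interval)

assert kv_slots == list(range(full_attn_layer_num))
assert mamba_slots == list(range(linear_attn_layer_num))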
file mode 100644 index 000000000..202ce7460 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -0,0 +1,1057 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.11.0rc1/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py + +from typing import Optional, Union + +import numpy as np +import torch + +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + batch_ptr, + token_chunk_offset_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + batch: tl.int32, # actually padded_batch + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.load(batch_ptr + tl.program_id(0)) + chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + # base of the sequence + x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = 
tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = ( + conv_states_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + # col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + # if KERNEL_WIDTH >= 5: # STRATEGY1 + # col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] & (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = 
conv_states_base[None, :] + (idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + # need this due to the bug in tl.where not enforcing this + # when data is the result of another tl.load + tl.debug_barrier() + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = ( + conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = ( + conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = 
tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + # col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token) * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + out: torch.Tensor, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=True, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch + sequences are concatenated from left to right for varlen + weight: (dim, width) + conv_states: (...,dim,width - 1) itype + updated inplace if 
provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + out: same shape as `x` + """ + if isinstance(activation, bool) and activation: + activation = "silu" + + args = None + # Store original dtype to cast back at the end + original_x_dtype = x.dtype + x = x.to(conv_states.dtype) + if metadata is not None: + nums_dict = metadata.nums_dict + args = nums_dict + batch_ptr = metadata.batch_ptr + token_chunk_offset_ptr = metadata.token_chunk_offset_ptr + else: + seqlens = np.diff(query_start_loc.to("cpu")) + args = seqlens + MAX_NUM_PROGRAMS = 1024 + + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device + ) # tracking which seq-idx the Triton program is handling + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device + ) # tracking BLOCK_M-based index in the sequence the Triton program is handling + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + padded_batch = query_start_loc.size(0) - 1 + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ), f"{num_cache_lines} {dim} {width} {conv_states.shape}" + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch,) + assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + if metadata is None: + + def num_program(META, seqlens): + tot = 0 + + mlist = [] + offsetlist = [] # type: ignore + + nums = -(-seqlens // META["BLOCK_M"]) + + tot = nums.sum().item() + mlist = np.repeat(np.arange(len(nums)), nums) + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) # chunk-idx if a sequence is split into multiple chunks + + if META["batch_ptr"].nelement() < len(mlist): + newlen = len(mlist) + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= len(mlist): + META["batch_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(mlist))) + META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(offsetlist))) + + META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) + META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(META["x_ptr"].device) + return tot + + else: + + def num_program(META, nums_dict): + tot = nums_dict[META["BLOCK_M"]]["tot"] + + mlist = nums_dict[META["BLOCK_M"]]["mlist"] + mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] + + offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] + + if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: + META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]]["token_chunk_offset_ptr"] + else: + if META["batch_ptr"].nelement() < mlist_len: + newlen = mlist_len + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= mlist_len: + META["batch_ptr"][0:mlist_len].copy_(mlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) + return tot + + def grid(META): + return ( + num_program(META, args), + triton.cdiv(dim, META["BLOCK_N"]), + ) + + if batch_ptr.device != x.device: + batch_ptr = batch_ptr.to(x.device) + token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + batch_ptr, + 
token_chunk_offset_ptr, + out, + # Matrix dimensions + padded_batch, + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + # launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out.to(original_x_dtype) + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. 
+ # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = 
w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + max_query_len: int = -1, + pad_slot_id: int = PAD_SLOT_ID, + validate_data=False, +): + """ + x: Input tensor which can take the following shapes: + + - `[batch, dim]` - single token prediction + - `[batch, dim, 
seqlen]` - single or multiple tokens prediction + - `[num_tokens, dim]` - continuous batching, where num_tokens is + the total tokens of all sequences in that batch + + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + num_accepted_tokens: (batch,), dtype int32 + If not None, it indicates the number of accepted tokens for each + sequence in the batch. + This is used in speculative decoding, where the conv_state is updated + in a sliding window manner. + query_start_loc: (batch + 1,) int32 + If not None, the inputs is given in a varlen fashion and this indicates + the starting index of each sequence in the batch. + max_query_len: int + If query_start_loc is not None, this indicates the maximum query + length in the batch. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + seqlen = max_query_len + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch,) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (dim, cu_seqlen) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + 
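# A minimal decode-time usage sketch for causal_conv1d_update, based on the signature and
# docstring above. The channels-last conv_state layout (stride along the feature dim == 1)
# mirrors the validate_data asserts in this function; the import path is the file added in
# this patch, and running it needs a CUDA device with Triton installed.
import torch
from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_update

batch, dim, width = 4, 256, 4
x = torch.randn(batch, dim, device="cuda", dtype=torch.float16)      # one new token per sequence
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
# conv_state cache: (num_cache_lines, dim, width - 1), contiguous along the feature dim
conv_state = torch.zeros(batch, width - 1, dim, device="cuda", dtype=torch.float16).transpose(1, 2)
conv_state_indices = torch.arange(batch, device="cuda", dtype=torch.int32)  # one cache line per sequence

out = causal_conv1d_update(
    x,
    conv_state,
    weight,
    bias=None,
    activation="silu",
    conv_state_indices=conv_state_indices,
)
assert out.shape == (batch, dim)  # conv_state is shifted and updated in place with the new token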
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + query_start_loc, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py new file mode 100644 index 000000000..0e89cf9f7 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py new file mode 100644 index 000000000..cd3b0962a --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py new file mode 100644 index 000000000..22c81ae63 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. 
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch" + f" seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." 
+ ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, "b t h ... -> b h t ...") + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py new file mode 100644 index 000000000..f20c95d90 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp +from .utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_G"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if 
STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = ( + tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + if SAVE_NEW_VALUE + else None + ) + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, 
stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py new file mode 100644 index 000000000..73c2e1f19 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
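+# Here it implements the chunkwise output pass of the gated delta rule: for
+# each chunk, chunk_fwd_kernel_o accumulates the inter-chunk contribution
+# q @ h and the causally masked intra-chunk term (q @ k^T) @ v, both scaled
+# by `scale` and optionally modulated by the exponentiated gate cumsum g.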
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + {"USE_G": lambda args: args["g"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None} +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # 
cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + if FLA_GDN_FIX_BT: + BT = 64 + else: + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 000000000..aa545e8ec --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp + + +@triton.heuristics( + {"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "USE_G": lambda args: args["g_cumsum"] is not None} +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * exp(b_g_diff) + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = 
tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py new file mode 100644 index 000000000..9cd6a6545 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
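+# Here it provides chunk_local_cumsum, which computes per-chunk (local)
+# cumulative sums of the gate tensor g in scalar ([B, T, H]) and vector
+# ([B, T, H, S]) variants, with optional reverse accumulation and
+# variable-length batching via cu_seqlens.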
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 
1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, S=S, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: + if not head_first and g.shape[1] < g.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch" + f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. 
" + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py new file mode 100644 index 000000000..4ff18d4f6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = 
tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ( + ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token + ) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + 
initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py new file mode 100644 index 000000000..8b1d59fc6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py new file mode 100644 index 000000000..7225cd4ae --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
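+# Here it provides l2norm_fwd, which normalizes the last dimension of the
+# input to unit L2 norm, i.e. x / sqrt(sum(x * x) + eps), selecting one of
+# three kernel variants based on USE_DEFAULT_FLA_NORM and the feature size.
+# A minimal usage sketch (the shapes below are illustrative assumptions):
+#     q = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
+#     q = l2norm_fwd(q)  # same shape back, unit L2 norm along the last dim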
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.autotune(configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], key=["D"]) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D"], +) +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + ) + else: + l2norm_fwd_kernel1[(T,)]( + x, + y, + eps=eps, + D=D, + BD=BD, + ) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py new file mode 100644 index 000000000..ec0999455 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -0,0 +1,36 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import triton +import triton.language as tl + + +@triton.jit +def div_normal(x, y): + return x / y + + +div = div_normal +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +if not hasattr(tl, "gather"): + + @triton.jit + def gather(src, index, axis, _builder=None): + # This is a fallback implementation when tl.gather is not supported + # In order to pass triton compiler, there is no actual gather operation + return src + +else: + gather = tl.gather diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py new file mode 100644 index 000000000..46e4d5082 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import input_guard + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["BT"], +) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for 
num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ad_33 = 
tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee") + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee") + + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision="ieee") + tl.dot(A_32, Ai_21, input_precision="ieee"), + input_precision="ieee", + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision="ieee") + tl.dot(A_43, Ai_32, input_precision="ieee"), + input_precision="ieee", + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision="ieee") + + tl.dot(A_42, Ai_21, input_precision="ieee") + + tl.dot(A_43, Ai_31, input_precision="ieee"), + input_precision="ieee", + ) + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_43, 
Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)) + p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)) + p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)) + p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)) + p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)) + p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)) + tl.store(p_Ai_12, fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_13, fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_14, fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_23, fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_24, fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_34, fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@input_guard +def solve_tril( + A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float +) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + + chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py new file mode 100644 index 000000000..d8f29f287 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
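+# Here it collects shared helpers for the kernels in this package: the
+# tensor_cache and input_guard decorators, Triton device/platform detection,
+# and shared-memory capacity checks (check_shared_mem) used when choosing
+# kernel block sizes.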
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from enum import Enum +from typing import Any, Callable, Literal, Optional + +import torch + +import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. 
+# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. +device = get_available_device() if get_available_device() != "hip" else "cuda" +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py new file mode 100644 index 000000000..dec8d2ffc --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
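+# Here it provides recompute_w_u_fwd, which rebuilds the chunked tensors
+# w = A @ (k * beta * exp(g)) and u = A @ (v * beta) from the keys, values,
+# betas, gate cumsum g, and the chunk-level matrix A.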
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py 
b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py new file mode 100644 index 000000000..99a5e2f70 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -0,0 +1,83 @@ +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def _get_fused_gdn_gating_configs(): + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [8, 16, 32, 64] for nw in [1, 2, 4]] + + +def _get_fused_gdn_gating_static_key(a: torch.Tensor): + # group by head size and input dtype + return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)} + + +@autotune( + kernel_name="fused_gdn_gating:v1", + configs_gen_func=_get_fused_gdn_gating_configs, + static_key_func=_get_fused_gdn_gating_static_key, + run_key_func=lambda a: a.shape[0], +) +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + run_config: dict | None = None, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + + # default heuristic when autotune is disabled + if not run_config: + # choose the largest block size that does not exceed num_heads + candidate_blk = [8, 16, 32, 64] + blk_heads = max([c for c in candidate_blk if c <= max(8, num_heads)] or [8]) + run_config = {"BLK_HEADS": blk_heads, "num_warps": 1} + + BLK_HEADS = run_config["BLK_HEADS"] + num_warps = run_config.get("num_warps", 1) + + grid = (batch, seq_len, triton.cdiv(num_heads, BLK_HEADS)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid]( + g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + BLK_HEADS, + num_warps=num_warps, + ) + return g diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py new file mode 100644 index 000000000..89db5e00c --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py @@ -0,0 +1,174 @@ +import triton +import triton.language as tl +import torch +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + } +) +@triton.jit +def gated_rmsnorm_forward_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch (required, not optional) + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number 
of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + Z += row * stride_z_row + group * N + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute variance (RMS norm doesn't use mean) + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + # RMS norm: compute variance directly without mean subtraction + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + # RMS norm: normalize without mean subtraction + x_hat = x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _get_gated_rmsnorm_configs(): + """Generate configurations for autotuning gated RMSNorm kernel.""" + configs = [] + # Different BLOCK_N sizes (powers of 2) + for block_n in [64, 128, 256, 512, 1024, 2048, 4096]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + # Skip configurations that are likely to be inefficient + if block_n >= 2048 and num_warps > 4: + continue + if block_n <= 128 and num_warps > 2: + continue + configs.append({"BLOCK_N": block_n, "num_warps": num_warps}) + return configs + + +def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + M, N = x.shape + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(weight.dtype), + "N": N, + "has_bias": bias is not None, + } + + +@autotune( + kernel_name="gated_rmsnorm_forward:v1", + configs_gen_func=_get_gated_rmsnorm_configs, + static_key_func=_get_gated_rmsnorm_static_key, + run_key_func=lambda x: x.shape[0], +) +def gated_rmsnorm_forward( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + run_config: dict = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + # z is required for gated_rmsnorm + assert z is not None, "z cannot be None for gated_rmsnorm_forward" + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + # For RMS norm, we still need rstd for the kernel + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, 
triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + + # Validate BLOCK_N against group_size + if group_size > BLOCK_N: + # Fall back to largest valid BLOCK_N + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M, ngroups) + gated_rmsnorm_forward_kernel[grid]( + x, + out, + weight, + bias, + z, + rstd, + x.stride(0), + out.stride(0), + z.stride(0), + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + num_warps=num_warps, + ) + return out diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py new file mode 100644 index 000000000..210e78db1 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py @@ -0,0 +1,144 @@ +import torch + +import triton +import triton.language as tl +import os + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _gemma_rmsnorm_fwd_kernel( + x_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + y_ptr = y_ptr + row * y_stride0 + + _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) + _sum += x * x + + var = tl.sum(_sum, axis=0) / N + rstd = 1 / tl.sqrt(var + EPS) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + w = w + 1.0 + y = x_hat * w + # Write output + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_gemma_rmsnorm_configs(): + """Generate configurations for autotuning gemma RMSNorm kernel.""" + configs = [] + # Different BLOCK_SIZE values (powers of 2) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + for num_stages in [1, 2, 3, 4, 5]: + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": num_stages}) + return configs + + +def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="gemma_rmsnorm_forward:v1", + configs_gen_func=_get_gemma_rmsnorm_configs, + static_key_func=_get_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], +) +def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): + # Inplace gemma RMS Norm + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + N = x.shape[-1] + y = torch.empty_like(x) if 
out is None else out
+    x_arg = x.view(-1, N)
+    y_arg = y.view(-1, N)
+
+    M, _ = x_arg.shape
+
+    # Default heuristic when autotune is disabled or no config provided
+    if not run_config:
+        # Less than 64KB per feature: enqueue fused kernel
+        MAX_FUSED_SIZE = 65536 // x.element_size()
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+        if N > BLOCK_SIZE:
+            raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.")
+        # heuristics for number of warps
+        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+        run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1}
+
+    BLOCK_SIZE = run_config["BLOCK_SIZE"]
+    num_warps = run_config["num_warps"]
+    num_stages = run_config["num_stages"]
+
+    _gemma_rmsnorm_fwd_kernel[(M,)](
+        x_arg,
+        w,
+        y_arg,
+        x_stride0=x.stride(0),
+        x_stride1=x.stride(1),
+        y_stride0=y.stride(0),
+        y_stride1=y.stride(1),
+        N=N,
+        EPS=eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return y
+
+
+def _gemma_rmsnorm_fwd_torch(x, weight, eps):
+    original_dtype = x.dtype
+    x = x.to(torch.float32)
+    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+    x = x * (1.0 + weight.float())
+    return x.to(original_dtype)
+
+
+def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
+    # create data on the requested device (default: cuda)
+    x_shape = (M, N)
+    w_shape = (x_shape[-1],)
+    weight = torch.rand(w_shape, dtype=dtype, device=device)
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
+    # forward pass
+    y_tri = gemma_rmsnorm_forward(x, weight, eps)
+    y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps)
+
+    # compare
+    print("type:", y_tri.dtype, y_ref.dtype)
+    print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
+    # Use appropriate tolerance based on dtype
+    atol = 1e-2 if dtype == torch.float32 else 5e-2
+    assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0)
+    return
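+
+
+if __name__ == "__main__":
+    # Illustrative smoke test: compares the Triton kernel against the torch
+    # reference above. The shape and dtype are arbitrary example values, and a
+    # CUDA-capable device is assumed to be available.
+    test_rms_norm(M=4, N=4096, dtype=torch.float16)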