
Commit a2c834e

support triton chunk_gated_delta_rule ops
Signed-off-by: shiyuan680 <[email protected]>
1 parent f7d1f73 commit a2c834e

File tree

16 files changed: +1644 −150 lines changed

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import unittest
from unittest.mock import patch

import pytest
import torch

from tests.ut.base import PytestBase
from vllm_ascend.ops.fla.chunk import chunk_gated_delta_rule


@pytest.fixture
def mock_moe_env():

    with patch("torch_npu.npu_moe_finalize_routing",
               return_value=(torch.randn(1, 17, 8, 128, dtype=torch.bfloat16),
                             torch.randn(3, 8, 128, 128,
                                         dtype=torch.bfloat16))):
        yield


class TestChunkGatedDeltaRule(PytestBase):

    def test_triton_fusion_ops(self, mock_moe_env):
        q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu()
        g = torch.randn(1, 17, 8, dtype=torch.float32).npu()
        beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu()
        initial_state = torch.randn(3, 8, 128, 128, dtype=torch.bfloat16).npu()
        q_start_loc = torch.range(0, 3, dtype=torch.int).npu()

        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = chunk_gated_delta_rule(q=q,
                                   k=k,
                                   v=v,
                                   g=g,
                                   beta=beta,
                                   initial_state=initial_state,
                                   output_final_state=True,
                                   cu_seqlens=q_start_loc,
                                   head_first=False,
                                   use_qk_l2norm_in_kernel=True)

        assert core_attn_out_non_spec.shape == (1, 17, 8, 128)
        assert last_recurrent_state.shape == (3, 8, 128, 128)


if __name__ == '__main__':
    unittest.main()

vllm_ascend/models/qwen3_next.py

Lines changed: 13 additions & 44 deletions
@@ -365,50 +365,19 @@ def _forward(
             non_spec_state_indices_tensor].contiguous()
         initial_state[~has_initial_state, ...] = 0
 
-        batch_size = initial_state.shape[0]
-        core_attn_out = []
-        last_recurrent_state = []
-
-        for b_idx in range(batch_size):
-            start, end = non_spec_query_start_loc[
-                b_idx], non_spec_query_start_loc[b_idx + 1]
-            cur_q = query_non_spec[:, start:end, ...]
-            cur_k = key_non_spec[:, start:end, ...]
-            cur_v = value_non_spec[:, start:end, ...]
-            cur_g = g_non_spec[:, start:end, ...]
-            cur_b = beta_non_spec[:, start:end, ...]
-            cur_state = initial_state[b_idx].unsqueeze(0)
-
-            (
-                cur_core_attn_out_non_spec,
-                cur_last_recurrent_state,
-            ) = chunk_gated_delta_rule(
-                query=cur_q,
-                key=cur_k,
-                value=cur_v,
-                g=cur_g,
-                beta=cur_b,
-                initial_state=cur_state,
-                output_final_state=True,
-                use_qk_l2norm_in_kernel=True,
-            )
-
-            core_attn_out.append(cur_core_attn_out_non_spec)
-            last_recurrent_state.append(cur_last_recurrent_state)
-
-        tar_dtype = core_attn_out[0].dtype
-        tar_device = core_attn_out[0].device
-        tar_shape = list(core_attn_out[0].shape)
-        tar_shape[1] = non_spec_query_start_loc[-1]
-        core_attn_out_non_spec = torch.empty(tar_shape,
-                                             dtype=tar_dtype,
-                                             device=tar_device)
-        for b_idx in range(batch_size):
-            cur_core_attn_out = core_attn_out[b_idx]
-            start, end = non_spec_query_start_loc[
-                b_idx], non_spec_query_start_loc[b_idx + 1]
-            core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
-        last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
+        (
+            core_attn_out_non_spec,
+            last_recurrent_state,
+        ) = chunk_gated_delta_rule(q=query_non_spec,
+                                   k=key_non_spec,
+                                   v=value_non_spec,
+                                   g=g_non_spec,
+                                   beta=beta_non_spec,
+                                   initial_state=initial_state,
+                                   output_final_state=True,
+                                   cu_seqlens=non_spec_query_start_loc,
+                                   head_first=False,
+                                   use_qk_l2norm_in_kernel=True)
 
         # Init cache
         ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
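
The hunk above replaces the per-sequence Python loop with a single variable-length kernel call: the inputs stay packed in one batch-size-1 tensor and non_spec_query_start_loc is passed as cu_seqlens, so the kernel resolves sequence boundaries itself instead of the model slicing each sequence in Python. A minimal sketch of that packing convention (a sketch only; the lengths and sizes below are illustrative, not taken from the repository):

import torch

# Three packed sequences of hypothetical lengths 5, 9 and 3.
lens = torch.tensor([5, 9, 3])
# cu_seqlens holds cumulative boundaries [0, 5, 14, 17]: sequence i occupies
# positions cu_seqlens[i]:cu_seqlens[i+1] along the packed time axis.
cu_seqlens = torch.cat([lens.new_zeros(1), lens.cumsum(0)])

T, N, H, K, V = int(lens.sum()), len(lens), 4, 128, 128
q = torch.randn(1, T, H, K, dtype=torch.bfloat16)    # batch size stays 1
k = torch.randn(1, T, H, K, dtype=torch.bfloat16)
v = torch.randn(1, T, H, V, dtype=torch.bfloat16)
g = torch.randn(1, T, H, dtype=torch.float32)
beta = torch.rand(1, T, H, dtype=torch.bfloat16)
h0 = torch.randn(N, H, K, V, dtype=torch.bfloat16)   # one initial state per sequence

# On an Ascend device the tensors would be moved with .npu() and passed as in
# the diff above, e.g.:
# o, ht = chunk_gated_delta_rule(q=q, k=k, v=v, g=g, beta=beta,
#                                initial_state=h0, output_final_state=True,
#                                cu_seqlens=cu_seqlens, head_first=False,
#                                use_qk_l2norm_in_kernel=True)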

vllm_ascend/ops/fla/__init__.py

Whitespace-only changes.

vllm_ascend/ops/fla/chunk.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import warnings
from typing import Optional

import torch
from einops import rearrange
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import input_guard
from .wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor,
                               beta: torch.Tensor,
                               scale: float,
                               initial_state: torch.Tensor,
                               output_final_state: bool,
                               cu_seqlens: Optional[torch.LongTensor] = None):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(k=k,
                                 beta=beta,
                                 g_cumsum=g,
                                 cu_seqlens=cu_seqlens,
                                 output_dtype=torch.float32)
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    elif SUPPRESS_LEVEL >= 3:
        return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                output_final_state: bool,
                cu_seqlens: Optional[torch.LongTensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           g: torch.Tensor,
                           beta: torch.Tensor,
                           scale: float = None,
                           initial_state: torch.Tensor = None,
                           output_final_state: bool = False,
                           cu_seqlens: Optional[torch.LongTensor] = None,
                           head_first: bool = False,
                           use_qk_l2norm_in_kernel: bool = False):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert len(
        beta.shape
    ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            stacklevel=2)
        q, k, v, beta, g = map(
            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
            (q, k, v, beta, g))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing.")
        if initial_state is not None and initial_state.shape[0] != len(
                cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1]**-0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
        use_qk_l2norm_in_kernel)
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state
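
For intuition about what this chunked kernel computes: ignoring chunking, precision, and kernel-level details, the op realizes the gated delta rule, where the state decays by exp(g_t) at every step, is corrected toward v_t with write strength beta_t, and the output is the query read against that state. The reference below is a naive per-timestep sketch for a single sequence, written from the docstring's shape conventions (head_first=False, state of shape [H, K, V]); the function name and this formulation are illustrative assumptions rather than code from the repository, and with use_qk_l2norm_in_kernel=True the kernel would additionally L2-normalize q and k before this recurrence.

import torch


def gated_delta_rule_reference(q, k, v, g, beta, scale=None, initial_state=None):
    # q, k: [T, H, K]; v: [T, H, V]; g, beta: [T, H]; returns ([T, H, V], [H, K, V]).
    # g is a log-space forget gate, beta the per-step write strength.
    T, H, K = q.shape
    V = v.shape[-1]
    scale = K**-0.5 if scale is None else scale
    S = (initial_state.float().clone() if initial_state is not None else
         q.new_zeros(H, K, V, dtype=torch.float32))
    q, k, v, g, beta = (x.float() for x in (q, k, v, g, beta))
    out = q.new_empty(T, H, V)
    for t in range(T):
        S = S * g[t].exp()[:, None, None]            # decay the previous state
        v_old = torch.einsum('hk,hkv->hv', k[t], S)  # value currently stored for key k_t
        delta = beta[t][:, None] * (v[t] - v_old)    # delta-rule correction toward v_t
        S = S + torch.einsum('hk,hv->hkv', k[t], delta)
        out[t] = torch.einsum('hk,hkv->hv', q[t] * scale, S)
    return out, S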
