
Commit 78bf211

[OPS] support triton causal_conv1d_fn ops (#4119)
### What this PR does / why we need it?
Support triton causal_conv1d_fn ops.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: QilaiZhang <[email protected]>
1 parent eac72f5 commit 78bf211

File tree

3 files changed: +636 -97 lines changed

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
from typing import Optional

import pytest
import torch
import torch.nn.functional as F

from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
                                                         causal_conv1d_fn)


def validate_cmp(y_cal, y_ref, dtype, device='npu'):
    y_cal = y_cal.to(device)
    y_ref = y_ref.to(device)
    if dtype == torch.float16:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=3e-03,
                                   atol=1e-02,
                                   equal_nan=True)
    elif dtype == torch.bfloat16:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-02,
                                   atol=1e-02,
                                   equal_nan=True)
    elif dtype == torch.float32:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-03,
                                   atol=4e-03,
                                   equal_nan=True)
    elif dtype == torch.int32 or dtype == torch.int64 or dtype == torch.int16 or dtype == torch.int8 or dtype == torch.uint32:
        assert torch.equal(y_cal, y_ref)
    elif dtype == torch.bool:
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError(
            'Invalid parameter \"dtype\" is found : {}'.format(dtype))


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape

    if initial_states is None:
        out = F.conv1d(x,
                       weight.unsqueeze(1),
                       bias,
                       padding=width - 1,
                       groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]

    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


def causal_conv1d_fn_pytorch(
    x: torch.Tensor,
    weight: torch.Tensor,
    query_start_loc: torch.Tensor,
    cache_indices: torch.Tensor,
    has_initial_state: torch.Tensor,
    conv_states: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether should the kernel take the current state as initial
        state for the calculations
    conv_states: (...,dim,width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    out_ref = []
    out_ref_b = []
    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)
    width = weight.shape[1]

    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]][..., :(
                    width - 1)].unsqueeze(0),
                initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
                if has_initial_state[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor


@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype',
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [False, True])
@pytest.mark.parametrize('has_bias', [False, True])
@pytest.mark.parametrize('seq_len', [[
    1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134,
    2048, 4096
]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [2, 3, 4])
@pytest.mark.parametrize('dim', [64, 4160])
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                       silu_activation, itype, has_initial_state):

    torch.random.manual_seed(0)

    device = "npu"
    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
    state_len = width - 1 + extra_state_len

    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
                                                device=device,
                                                dtype=torch.int32),
                                   dim=0)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = None if not silu_activation else "silu"

    if has_initial_state:
        conv_states = torch.randn((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.randn(
            (num_seq, state_len, dim), device=device,
            dtype=itype).transpose(-1, -2).copy_(conv_states)
    else:
        conv_states = torch.zeros((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.zeros((num_seq, state_len, dim),
                                      device=device,
                                      dtype=itype).transpose(-1, -2)

    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype)
    else:
        bias = None

    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)
    out = causal_conv1d_fn(x,
                           weight,
                           bias=bias,
                           activation=activation,
                           conv_states=conv_states,
                           has_initial_state=has_initial_state_tensor,
                           cache_indices=cache_indices,
                           query_start_loc=query_start_loc)

    validate_cmp(out, out_ref, itype)
    validate_cmp(conv_states, conv_states_ref, itype)
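
For readers unfamiliar with the varlen calling convention exercised above, here is a minimal usage sketch. It is not part of this commit: the sequence lengths, dim/width, dtype, and the "npu" device are illustrative assumptions, mirroring the query_start_loc = [0, 10, 16, 17] example from the docstring in the test file.

# Minimal usage sketch (illustrative, not part of the diff): three variable-
# length sequences of 10, 6 and 1 tokens concatenated along the last dim,
# each with its own conv-state slot. Shapes, dtype and device are assumptions.
import torch

from vllm_ascend.ops.triton.mamba.causal_conv1d import causal_conv1d_fn

dim, width = 64, 4
seq_lens = [10, 6, 1]
num_seqs, total = len(seq_lens), sum(seq_lens)  # total = 17

x = torch.randn(dim, total, device="npu", dtype=torch.float16)
weight = torch.randn(dim, width, device="npu", dtype=torch.float16)
bias = torch.randn(dim, device="npu", dtype=torch.float16)

# Cumulative sequence boundaries prepended by 0: tensor([0, 10, 16, 17])
query_start_loc = torch.cumsum(torch.tensor([0] + seq_lens,
                                            device="npu",
                                            dtype=torch.int32),
                               dim=0)
# Sequence i reads/writes conv_states[cache_indices[i]].
cache_indices = torch.arange(num_seqs, device="npu", dtype=torch.int32)
# (num_seqs, dim, width - 1) view, built via transpose as in the test above.
conv_states = torch.zeros((num_seqs, width - 1, dim),
                          device="npu",
                          dtype=torch.float16).transpose(-1, -2)
# No sequence carries an initial state in from a previous chunk here.
has_initial_state = torch.zeros(num_seqs, device="npu", dtype=torch.bool)

out = causal_conv1d_fn(x,
                       weight,
                       bias=bias,
                       activation="silu",
                       conv_states=conv_states,  # updated in place
                       has_initial_state=has_initial_state,
                       cache_indices=cache_indices,
                       query_start_loc=query_start_loc)
# out is expected in the same (dim, 17) varlen layout as x; per the docstring,
# conv_states is updated in place with each sequence's trailing inputs.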
