from typing import Optional

import pytest
import torch
import torch.nn.functional as F

from vllm_ascend.ops.triton.mamba.causal_conv1d import (PAD_SLOT_ID,
                                                        causal_conv1d_fn)


def validate_cmp(y_cal, y_ref, dtype, device='npu'):
    y_cal = y_cal.to(device)
    y_ref = y_ref.to(device)
    if dtype == torch.float16:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=3e-03,
                                   atol=1e-02,
                                   equal_nan=True)
    elif dtype == torch.bfloat16:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-02,
                                   atol=1e-02,
                                   equal_nan=True)
    elif dtype == torch.float32:
        torch.testing.assert_close(y_ref,
                                   y_cal,
                                   rtol=1e-03,
                                   atol=4e-03,
                                   equal_nan=True)
    elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64,
                   torch.uint32, torch.bool):
        # Integer and boolean results must match exactly.
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError('Invalid parameter "dtype": {}'.format(dtype))


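# Usage sketch for the helper above (hypothetical CPU tensors): both operands
# are moved to the target device, then compared with dtype-specific
# tolerances.
def _example_validate_cmp():
    a = torch.ones(2, 2, dtype=torch.float32)
    validate_cmp(a, a.clone(), torch.float32, device='cpu')

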
def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape

    if initial_states is None:
        out = F.conv1d(x,
                       weight.unsqueeze(1),
                       bias,
                       padding=width - 1,
                       groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]

    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


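# A minimal sketch (hypothetical shapes, runs on CPU) of the causality
# property the reference above implements: with left padding of width - 1,
# output position t depends only on input positions <= t, so truncating the
# input truncates the output.
def _example_causal_property():
    x = torch.randn(1, 4, 8)  # (batch, dim, seqlen)
    w = torch.randn(4, 3)  # (dim, width)
    y_full, _ = causal_conv1d_ref(x, w, activation=None)
    y_prefix, _ = causal_conv1d_ref(x[..., :5], w, activation=None)
    # The first 5 outputs are unaffected by inputs at positions >= 5.
    torch.testing.assert_close(y_full[..., :5], y_prefix)

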
def causal_conv1d_fn_pytorch(
    x: torch.Tensor,
    weight: torch.Tensor,
    query_start_loc: torch.Tensor,
    cache_indices: torch.Tensor,
    has_initial_state: torch.Tensor,
    conv_states: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen;
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into x; prepended by 0.
        For example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculations
    conv_states: (..., dim, width - 1) itype
        updated in place if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    out_ref = []
    out_ref_b = []
    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)
    width = weight.shape[1]

    for i in range(len(seqlens)):
        x_s = splits[i]
        # Skip padded entries; their states are neither read nor written.
        if cache_indices[i] == pad_slot_id:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]][..., :(
                    width - 1)].unsqueeze(0),
                initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
                if has_initial_state[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor


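# Illustration of the varlen layout consumed above (values taken from the
# docstring's example): query_start_loc = [0, 10, 16, 17] describes three
# sequences of lengths [10, 6, 1] concatenated along the last axis of x.
def _example_varlen_split():
    dim = 4
    x = torch.randn(dim, 17)
    query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32)
    seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    splits = torch.split(x, seqlens, dim=-1)
    assert [s.shape[-1] for s in splits] == [10, 6, 1]

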
@pytest.mark.parametrize('has_initial_state', [False, True])
@pytest.mark.parametrize('itype',
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('silu_activation', [False, True])
@pytest.mark.parametrize('has_bias', [False, True])
@pytest.mark.parametrize('seq_len', [[
    1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134,
    2048, 4096
]])
@pytest.mark.parametrize('extra_state_len', [0, 2])
@pytest.mark.parametrize('width', [2, 3, 4])
@pytest.mark.parametrize('dim', [64, 4160])
def test_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
                       silu_activation, itype, has_initial_state):

    torch.random.manual_seed(0)

    device = "npu"
    cu_seqlen, num_seq = sum(seq_len), len(seq_len)
    state_len = width - 1 + extra_state_len

    x = torch.randn(cu_seqlen, dim, device=device, dtype=itype).transpose(0, 1)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    query_start_loc = torch.cumsum(torch.tensor([0] + seq_len,
                                                device=device,
                                                dtype=torch.int32),
                                   dim=0)
    cache_indices = torch.arange(num_seq, device=device, dtype=torch.int32)
    has_initial_state_tensor = torch.tensor([has_initial_state] * num_seq,
                                            device=device,
                                            dtype=torch.bool)
    activation = None if not silu_activation else "silu"

    if has_initial_state:
        conv_states = torch.randn((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.randn(
            (num_seq, state_len, dim), device=device,
            dtype=itype).transpose(-1, -2).copy_(conv_states)
    else:
        conv_states = torch.zeros((num_seq, state_len, dim),
                                  device=device,
                                  dtype=itype).transpose(-1, -2)
        conv_states_ref = torch.zeros((num_seq, state_len, dim),
                                      device=device,
                                      dtype=itype).transpose(-1, -2)

    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype)
    else:
        bias = None

    out_ref = causal_conv1d_fn_pytorch(
        x,
        weight,
        bias=bias,
        activation=activation,
        conv_states=conv_states_ref,
        has_initial_state=has_initial_state_tensor,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc)
    out = causal_conv1d_fn(x,
                           weight,
                           bias=bias,
                           activation=activation,
                           conv_states=conv_states,
                           has_initial_state=has_initial_state_tensor,
                           cache_indices=cache_indices,
                           query_start_loc=query_start_loc)

    validate_cmp(out, out_ref, itype)
    validate_cmp(conv_states, conv_states_ref, itype)