@@ -133,33 +133,133 @@ def _cutlass_blackwell_fmha_backward(
     )
 
 
-def _cutlass_blackwell_fmha_gen(
+def _validate_decode_inputs(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    seqlen_kv: torch.Tensor,
-    batch_idx: torch.Tensor,
-    kernel_type: GenKernelType = GenKernelType.UMMA_I,
+    seqlen_kv: torch.Tensor | None,
+) -> None:
+    assert seqlen_kv is not None, "seqlen_kv must be provided for decode"
+    tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv}
+
+    for name, tensor in tensors.items():
+        # assert tensor.is_contiguous(), f"{name} is not contiguous"
+        assert tensor.is_cuda, f"{name} must be on GPU"
+
+
+def _prepare_decode_inputs(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]:
153+ """
154+ Prepare inputs for decode kernel by handling both varlen and batch formats.
155+
156+ Returns:
157+ - Reshaped q, k, v tensors in batch format [B, 1, H, D]
158+ - batch_size
159+ - needs_reshape_output flag
160+ - original_shape of q
161+ """
+    original_shape = tuple(q.shape)
+    needs_reshape_output = False
+    batch_size = q.shape[0]
+
+    if q.dim() == 3:
+        # Varlen format: [total_queries, num_heads, head_dim]
+        q = q.view(batch_size, 1, q.shape[1], q.shape[2])
+        needs_reshape_output = True
+
+    if q.dim() != 4:
+        raise ValueError(
+            f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]"
+        )
+    assert q.shape[1] == 1, "Gen kernel requires query length sq == 1"
+
+    k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k
+    v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v
+
+    return q, k, v, batch_size, needs_reshape_output, original_shape
+
+
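+# A minimal illustrative sketch of the layout handling above (not exercised by
+# this change; shapes are assumed: B=8 sequences, H=16 heads, D=128 head dim,
+# 512 cached KV tokens per sequence):
+#
+#   q = torch.randn(8, 16, 128, device="cuda")        # varlen [total_q, H, D]
+#   k = torch.randn(8 * 512, 16, 128, device="cuda")  # varlen [total_k, H, D]
+#   v = torch.randn_like(k)
+#   q4, k4, v4, b, reshape, orig = _prepare_decode_inputs(q, k, v)
+#   # q4.shape == (8, 1, 16, 128), k4.shape == (8, 512, 16, 128)
+#   # b == 8, reshape is True, orig == (8, 16, 128)
+#   # 4D [B, 1, H, D] inputs pass through unchanged with reshape == False.
+
+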
+def _create_decode_lse(
+    out: torch.Tensor,
+    batch_size: int,
+    needs_reshape_output: bool,
+    q_shape: tuple[int, ...],
 ) -> torch.Tensor:
-    assert q.is_contiguous(), "q is not contiguous"
-    assert k.is_contiguous(), "k is not contiguous"
-    assert v.is_contiguous(), "v is not contiguous"
-    assert seqlen_kv.is_contiguous(), "seqlen_kv is not contiguous"
-    assert batch_idx.is_contiguous(), "batch_idx is not contiguous"
-    assert q.is_cuda, "q must be on GPU"
-    assert k.is_cuda, "k must be on GPU"
-    assert v.is_cuda, "v must be on GPU"
-    assert seqlen_kv.is_cuda, "seqlen_kv must be on GPU"
-    assert batch_idx.is_cuda, "batch_idx must be on GPU"
-    return torch.ops.fbgemm.fmha_gen_fwd(
189+ """
190+ Create dummy LSE tensor for decode output compatibility.
191+ Gen kernel doesn't return LSE, so we create a zero tensor.
192+ """
193+ if needs_reshape_output :
194+ # For varlen output format
195+ lse_shape = [batch_size , q_shape [- 1 ]] # [B, H]
196+ else :
197+ # For batch output format
198+ lse_shape = [batch_size , q_shape [- 2 ], q_shape [1 ]] # [B, H, 1]
199+
200+ return torch .zeros (* lse_shape , dtype = torch .float32 , device = out .device )
201+
202+
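+# Shape sketch for the dummy LSE above (assumed example values B=8, H=16,
+# D=128; `out` stands in for any CUDA tensor supplying the device):
+#
+#   _create_decode_lse(out, 8, True, (8, 16, 128)).shape      # (8, 16)     [B, H]
+#   _create_decode_lse(out, 8, False, (8, 1, 16, 128)).shape  # (8, 16, 1)  [B, H, 1]
+
+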
+def cutlass_blackwell_fmha_decode_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seqlen_kv: torch.Tensor | None = None,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
218+ """
219+ Decode-optimized forward pass using the gen kernel.
220+ This is a wrapper to use the gen kernel which is optimized
221+ for decode (query length = 1).
222+
223+ This function is called externally by xformers ops.
224+
225+ Accepts inputs in two formats:
226+ - Varlen format: [total_queries, num_heads, head_dim] (3D)
227+ - Batch format: [batch_size, 1, num_heads, head_dim] (4D)
228+ """
+    _validate_decode_inputs(q, k, v, seqlen_kv)
+    # Handle window size for causal attention
+    if causal and window_left >= 0:
+        window_right = 0
+
+    # Prepare inputs and handle format conversion
+    q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs(
+        q, k, v
+    )
+
+    # Create batch_idx tensor
+    batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
+
+    # Call the gen kernel (optimized for decode)
+    out = torch.ops.fbgemm.fmha_gen_fwd(
         q,
         k,
         v,
         seqlen_kv,
         batch_idx,
-        kernel_type,
+        kernel_type=GenKernelType.UMMA_I,
+        # window_left=window_left,
+        # window_right=window_right,
     )
 
+    # Reshape output back to original format if needed
+    if needs_reshape_output:
+        out = out.view(*original_shape)
+
+    # Create dummy LSE for compatibility
+    lse = _create_decode_lse(out, batch_size, needs_reshape_output, original_shape)
+
+    return out, lse
+
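+# Minimal usage sketch (assumed shapes and dtypes; requires a Blackwell GPU
+# with the fbgemm ops loaded):
+#
+#   q = torch.randn(4, 1, 16, 128, device="cuda", dtype=torch.bfloat16)
+#   k = torch.randn(4, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
+#   v = torch.randn_like(k)
+#   seqlen_kv = torch.full((4,), 1024, dtype=torch.int32, device="cuda")
+#   out, lse = cutlass_blackwell_fmha_decode_forward(q, k, v, seqlen_kv=seqlen_kv)
+#   # out.shape == (4, 1, 16, 128); lse is a zero placeholder (see above)
+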
 
 class CutlassBlackwellFmhaFunc(torch.autograd.Function):
     @staticmethod
@@ -181,70 +281,68 @@ def forward(  # type: ignore
         bottom_right: bool = True,
         deterministic: bool = False,
     ) -> torch.Tensor:
+        window_left, window_right = window_size
         # Check if this is generation phase (sq = 1)
         sq = q.shape[1]
-        # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
-        if cu_seqlens_q is not None and cu_seqlens_k is not None:
-            assert (
-                cu_seqlens_q.dtype == torch.int32
-                and cu_seqlens_q.dtype == cu_seqlens_k.dtype
-            ), "cu_seqlens_q and cu_seqlens_k must be int32"
-
-        # handle window_size
-        window_left, window_right = window_size
-        if causal and window_left >= 0:
-            window_right = 0
-
         if q.dim() == 4 and sq == 1:
-            batch_size = q.shape[0]
-
-            # Use provided seqlen_kv
-            assert (
-                seqlen_kv is not None
-            ), "seqlen_kv must be provided for generation phase"
-
-            # Create batch_idx tensor
-            batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
-
-            # Use gen forward (no backward needed for generation)
-            out = _cutlass_blackwell_fmha_gen(
-                q, k, v, seqlen_kv, batch_idx, kernel_type=GenKernelType.UMMA_I
-            )
             # For gen case, we don't need to save tensors for backward
             ctx.is_gen = True
-            return out
-        else:
-            # Use regular FMHA for non-generation case
-            out, softmax_lse = _cutlass_blackwell_fmha_forward(
+            out, _ = cutlass_blackwell_fmha_decode_forward(
                 q,
                 k,
                 v,
+                seqlen_kv,
                 cu_seqlens_q,
                 cu_seqlens_k,
                 max_seq_len_q,
                 max_seq_len_k,
                 softmax_scale,
                 causal,
-                seqlen_kv,
-                page_table,
-                seqlen_k,
                 window_left,
                 window_right,
                 bottom_right,
             )
-            ctx.save_for_backward(q, k, v, out, softmax_lse)
-            ctx.softmax_scale = softmax_scale
-            ctx.causal = causal
-            ctx.window_size = window_size
-            ctx.max_seq_len_q = max_seq_len_q
-            ctx.max_seq_len_k = max_seq_len_k
-            ctx.cu_seqlens_q = cu_seqlens_q
-            ctx.cu_seqlens_k = cu_seqlens_k
-            ctx.is_gen = False
-            ctx.bottom_right = bottom_right
-            ctx.deterministic = deterministic
             return out
 
+        ctx.is_gen = False
+        # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
+        if cu_seqlens_q is not None and cu_seqlens_k is not None:
+            assert (
+                cu_seqlens_q.dtype == torch.int32
+                and cu_seqlens_q.dtype == cu_seqlens_k.dtype
+            ), "cu_seqlens_q and cu_seqlens_k must be int32"
+
+        # handle window_size
+        if causal and window_left >= 0:
+            window_right = 0
+        # Use regular FMHA for non-generation case
+        out, softmax_lse = _cutlass_blackwell_fmha_forward(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seq_len_q,
+            max_seq_len_k,
+            softmax_scale,
+            causal,
+            seqlen_kv,
+            window_left,
+            window_right,
+            bottom_right,
+        )
+        ctx.save_for_backward(q, k, v, out, softmax_lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.max_seq_len_q = max_seq_len_q
+        ctx.max_seq_len_k = max_seq_len_k
+        ctx.cu_seqlens_q = cu_seqlens_q
+        ctx.cu_seqlens_k = cu_seqlens_k
+        ctx.bottom_right = bottom_right
+        ctx.deterministic = deterministic
+        return out
+
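+    # Dispatch summary (illustrative, assumed shapes):
+    #   q of shape [B, 1, H, D] with sq == 1 -> decode/gen kernel (no backward)
+    #   any other layout                     -> regular FMHA with saved context
+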
     @staticmethod
     def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[  # type: ignore
         torch.Tensor,