
Commit 71a6fdc

Aya-ZIbra authored and facebook-github-bot committed
Kernel interface for use in xformers (#5063)
Summary:
X-link: facebookresearch/FBGEMM#2071

This diff introduces a kernel interface for use in xformers, specifically for Blackwell decode. xformers-side support is added in D84630701.

Reviewed By: sryap, jianyuh

Differential Revision: D84835511
1 parent 34754ea commit 71a6fdc
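
The new entry point introduced here, cutlass_blackwell_fmha_decode_forward, accepts queries in either varlen (3D) or batch (4D) format and returns the attention output together with a placeholder LSE. Below is a minimal, hypothetical usage sketch; the shapes, dtypes, and sizes are illustrative assumptions and are not part of this commit.

# Hypothetical usage sketch. Assumes cutlass_blackwell_fmha_decode_forward
# (defined in the interface module in the diff below) is already imported and
# a supported GPU is available.
import torch

B, H, D = 4, 8, 128   # batch size, query heads, head dim (illustrative)
S_kv = 1024           # allocated KV length per sequence (illustrative)

# Batch decode format: one query token per sequence -> [B, 1, H, D]
q = torch.randn(B, 1, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S_kv, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S_kv, H, D, device="cuda", dtype=torch.bfloat16)
# Number of valid KV tokens per sequence (dtype assumed int32)
seqlen_kv = torch.full((B,), 512, device="cuda", dtype=torch.int32)

out, lse = cutlass_blackwell_fmha_decode_forward(q, k, v, seqlen_kv=seqlen_kv)
# out has shape [B, 1, H, D]; lse is a zero-filled placeholder because the
# gen kernel does not return LSE.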

File tree

1 file changed: +158, -60 lines


fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py

Lines changed: 158 additions & 60 deletions
@@ -133,33 +133,133 @@ def _cutlass_blackwell_fmha_backward(
     )


-def _cutlass_blackwell_fmha_gen(
+def _validate_decode_inputs(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    seqlen_kv: torch.Tensor,
-    batch_idx: torch.Tensor,
-    kernel_type: GenKernelType = GenKernelType.UMMA_I,
+    seqlen_kv: torch.Tensor | None,
+) -> None:
+    assert seqlen_kv is not None, "seqlen_kv must be provided for decode"
+    tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv}
+
+    for name, tensor in tensors.items():
+        # assert tensor.is_contiguous(), f"{name} is not contiguous"
+        assert tensor.is_cuda, f"{name} must be on GPU"
+
+
+def _prepare_decode_inputs(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]:
+    """
+    Prepare inputs for decode kernel by handling both varlen and batch formats.
+
+    Returns:
+        - Reshaped q, k, v tensors in batch format [B, 1, H, D]
+        - batch_size
+        - needs_reshape_output flag
+        - original_shape of q
+    """
+    original_shape = tuple(q.shape)
+    needs_reshape_output = False
+    batch_size = q.shape[0]
+
+    if q.dim() == 3:
+        # Varlen format: [total_queries, num_heads, head_dim]
+        q = q.view(batch_size, 1, q.shape[1], q.shape[2])
+        needs_reshape_output = True
+
+    if q.dim() != 4:
+        raise ValueError(
+            f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]"
+        )
+    assert q.shape[1] == 1, "Kernel have sq=1"
+
+    k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k
+    v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v
+
+    return q, k, v, batch_size, needs_reshape_output, original_shape
+
+
+def _create_decode_lse(
+    out: torch.Tensor,
+    batch_size: int,
+    needs_reshape_output: bool,
+    q_shape: tuple[int, ...],
 ) -> torch.Tensor:
-    assert q.is_contiguous(), "q is not contiguous"
-    assert k.is_contiguous(), "k is not contiguous"
-    assert v.is_contiguous(), "v is not contiguous"
-    assert seqlen_kv.is_contiguous(), "seqlen_kv is not contiguous"
-    assert batch_idx.is_contiguous(), "batch_idx is not contiguous"
-    assert q.is_cuda, "q must be on GPU"
-    assert k.is_cuda, "k must be on GPU"
-    assert v.is_cuda, "v must be on GPU"
-    assert seqlen_kv.is_cuda, "seqlen_kv must be on GPU"
-    assert batch_idx.is_cuda, "batch_idx must be on GPU"
-    return torch.ops.fbgemm.fmha_gen_fwd(
+    """
+    Create dummy LSE tensor for decode output compatibility.
+    Gen kernel doesn't return LSE, so we create a zero tensor.
+    """
+    if needs_reshape_output:
+        # For varlen output format
+        lse_shape = [batch_size, q_shape[-1]]  # [B, H]
+    else:
+        # For batch output format
+        lse_shape = [batch_size, q_shape[-2], q_shape[1]]  # [B, H, 1]
+
+    return torch.zeros(*lse_shape, dtype=torch.float32, device=out.device)
+
+
+def cutlass_blackwell_fmha_decode_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seqlen_kv: torch.Tensor | None = None,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Decode-optimized forward pass using the gen kernel.
+    This is a wrapper to use the gen kernel which is optimized
+    for decode (query length = 1).
+
+    This function is called externally by xformers ops.
+
+    Accepts inputs in two formats:
+    - Varlen format: [total_queries, num_heads, head_dim] (3D)
+    - Batch format: [batch_size, 1, num_heads, head_dim] (4D)
+    """
+    _validate_decode_inputs(q, k, v, seqlen_kv)
+    # Handle window size for causal attention
+    if causal and window_left >= 0:
+        window_right = 0
+
+    # Prepare inputs and handle format conversion
+    q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs(
+        q, k, v
+    )
+
+    # Create batch_idx tensor
+    batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
+
+    # Call the gen kernel (optimized for decode)
+    out = torch.ops.fbgemm.fmha_gen_fwd(
         q,
         k,
         v,
         seqlen_kv,
         batch_idx,
-        kernel_type,
+        kernel_type=GenKernelType.UMMA_I,
+        # window_left=window_left,
+        # window_right=window_right,
     )

+    # Reshape output back to original format if needed
+    if needs_reshape_output:
+        out = out.view(*original_shape)
+
+    # Create dummy LSE for compatibility
+    lse = _create_decode_lse(out, batch_size, needs_reshape_output, original_shape)
+
+    return out, lse
+

 class CutlassBlackwellFmhaFunc(torch.autograd.Function):
     @staticmethod
@@ -181,70 +281,68 @@ def forward( # type: ignore
         bottom_right: bool = True,
         deterministic: bool = False,
     ) -> torch.Tensor:
+        window_left, window_right = window_size
         # Check if this is generation phase (sq = 1)
         sq = q.shape[1]
-        # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
-        if cu_seqlens_q is not None and cu_seqlens_k is not None:
-            assert (
-                cu_seqlens_q.dtype == torch.int32
-                and cu_seqlens_q.dtype == cu_seqlens_k.dtype
-            ), "cu_seqlens_q and cu_seqlens_k must be int32"
-
-        # handle window_size
-        window_left, window_right = window_size
-        if causal and window_left >= 0:
-            window_right = 0
-
         if q.dim() == 4 and sq == 1:
-            batch_size = q.shape[0]
-
-            # Use provided seqlen_kv
-            assert (
-                seqlen_kv is not None
-            ), "seqlen_kv must be provided for generation phase"
-
-            # Create batch_idx tensor
-            batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
-
-            # Use gen forward (no backward needed for generation)
-            out = _cutlass_blackwell_fmha_gen(
-                q, k, v, seqlen_kv, batch_idx, kernel_type=GenKernelType.UMMA_I
-            )
            # For gen case, we don't need to save tensors for backward
            ctx.is_gen = True
-            return out
-        else:
-            # Use regular FMHA for non-generation case
-            out, softmax_lse = _cutlass_blackwell_fmha_forward(
+            out, _ = cutlass_blackwell_fmha_decode_forward(
                q,
                k,
                v,
+                seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seq_len_q,
                max_seq_len_k,
                softmax_scale,
                causal,
-                seqlen_kv,
-                page_table,
-                seqlen_k,
                window_left,
                window_right,
                bottom_right,
            )
-            ctx.save_for_backward(q, k, v, out, softmax_lse)
-            ctx.softmax_scale = softmax_scale
-            ctx.causal = causal
-            ctx.window_size = window_size
-            ctx.max_seq_len_q = max_seq_len_q
-            ctx.max_seq_len_k = max_seq_len_k
-            ctx.cu_seqlens_q = cu_seqlens_q
-            ctx.cu_seqlens_k = cu_seqlens_k
-            ctx.is_gen = False
-            ctx.bottom_right = bottom_right
-            ctx.deterministic = deterministic
            return out

+        ctx.is_gen = False
+        # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
+        if cu_seqlens_q is not None and cu_seqlens_k is not None:
+            assert (
+                cu_seqlens_q.dtype == torch.int32
+                and cu_seqlens_q.dtype == cu_seqlens_k.dtype
+            ), "cu_seqlens_q and cu_seqlens_k must be int32"
+
+        # handle window_size
+        if causal and window_left >= 0:
+            window_right = 0
+        # Use regular FMHA for non-generation case
+        out, softmax_lse = _cutlass_blackwell_fmha_forward(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seq_len_q,
+            max_seq_len_k,
+            softmax_scale,
+            causal,
+            seqlen_kv,
+            window_left,
+            window_right,
+            bottom_right,
+        )
+        ctx.save_for_backward(q, k, v, out, softmax_lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.max_seq_len_q = max_seq_len_q
+        ctx.max_seq_len_k = max_seq_len_k
+        ctx.cu_seqlens_q = cu_seqlens_q
+        ctx.cu_seqlens_k = cu_seqlens_k
+        ctx.bottom_right = bottom_right
+        ctx.deterministic = deterministic
+        return out
+
     @staticmethod
     def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[ # type: ignore
         torch.Tensor,