Skip to content

Commit 53a6da4

Browse files
authored
enable xqa fp8 output (#2081)
1 parent 96e73b8 commit 53a6da4

File tree

12 files changed

+86
-35
lines changed

12 files changed

+86
-35
lines changed

csrc/flashinfer_xqa_binding.cu

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);
2727
#else
2828

2929
void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
30-
int64_t slidingWinSize, double qScale, TensorView output,
31-
#if LOW_PREC_OUTPUT
32-
TensorView rcpOutScale,
33-
#endif
30+
int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale,
3431
TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
3532
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
3633
TensorView seqLen, int64_t batchSize, double kvCacheScale,

csrc/xqa/mha.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,7 @@ CUBIN_EXPORT __global__
12811281
float qScale,
12821282
OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads]
12831283
#if LOW_PREC_OUTPUT
1284-
float const* rcpOutScale,
1284+
float rcpOutScale,
12851285
#endif
12861286
// NOTE: the input is actually Q buffer when integrated to TRT-LLM.
12871287
IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads],
@@ -2165,7 +2165,7 @@ CUBIN_EXPORT __global__
21652165
}
21662166
ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum);
21672167
#if LOW_PREC_OUTPUT
2168-
voScale *= rcpOutScale[0];
2168+
voScale *= rcpOutScale;
21692169
#endif
21702170
rescaleAcc(warp, acc, fullRescaleMask, rcpRowSum * ThrdRegRowMax::filled(voScale));
21712171
}
@@ -2396,7 +2396,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
23962396
float qScale,
23972397
OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads]
23982398
#if LOW_PREC_OUTPUT
2399-
float const* rcpOutScale,
2399+
float rcpOutScale,
24002400
#endif
24012401
IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads],
24022402
#if SPEC_DEC
@@ -2447,7 +2447,7 @@ void launchMHA(
24472447
#endif
24482448
float qScale, OutputHead* output,
24492449
#if LOW_PREC_OUTPUT
2450-
float const* rcpOutScale,
2450+
float rcpOutScale,
24512451
#endif
24522452
#if USE_INPUT_KV
24532453
InputHead const* qkv,
@@ -2563,7 +2563,7 @@ static uint32_t const hostSmemSize = configureKernel();
25632563
void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
25642564
float qScale, OutputHead* output,
25652565
#if LOW_PREC_OUTPUT
2566-
float const* rcpOutScale,
2566+
float rcpOutScale,
25672567
#endif
25682568
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
25692569
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,

csrc/xqa/mha.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ void launchMHA(
9595
#endif
9696
float qScale, OutputHead* output,
9797
#if LOW_PREC_OUTPUT
98-
float const* rcpOutScale,
98+
float rcpOutScale,
9999
#endif
100100
#if USE_INPUT_KV
101101
InputHead const* qkv,
@@ -125,7 +125,7 @@ void launchMHA(
125125
void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
126126
float qScale, OutputHead* output,
127127
#if LOW_PREC_OUTPUT
128-
float const* rcpOutScale,
128+
float rcpOutScale,
129129
#endif
130130
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
131131
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,
@@ -145,7 +145,7 @@ void launchHopperF8MHA(
145145
#endif
146146
float qScale, OutputHead* output,
147147
#if LOW_PREC_OUTPUT
148-
float const* rcpOutScale,
148+
float rcpOutScale,
149149
#endif
150150
#if USE_INPUT_KV
151151
InputHead const* qkv,
@@ -174,7 +174,7 @@ void launchHopperF8MHA(
174174
void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads,
175175
uint32_t slidingWinSize, float qScale, OutputHead* output,
176176
#if LOW_PREC_OUTPUT
177-
float const* rcpOutScale,
177+
float rcpOutScale,
178178
#endif
179179
InputHead const* q, float const* attentionSinks,
180180
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,

csrc/xqa/mha_sm90.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ __launch_bounds__(128 * 3)
610610
float const qScale,
611611
OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads]
612612
#if LOW_PREC_OUTPUT
613-
float const* const rcpOutScale,
613+
float rcpOutScale,
614614
#endif
615615
#if USE_INPUT_KV
616616
IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads],
@@ -957,7 +957,7 @@ __launch_bounds__(128 * 3)
957957

958958
constexpr float xScale = 1.f / kE4M3_MAX;
959959
#if LOW_PREC_OUTPUT
960-
float const oScale = rcpOutScale[0];
960+
float const oScale = rcpOutScale;
961961
#else
962962
constexpr float oScale = 1.F;
963963
#endif
@@ -2910,7 +2910,7 @@ void launchHopperF8MHA(
29102910
#endif
29112911
float qScale, OutputHead* output,
29122912
#if LOW_PREC_OUTPUT
2913-
float const* rcpOutScale,
2913+
float rcpOutScale,
29142914
#endif
29152915
#if USE_INPUT_KV
29162916
InputHead const* qkv,
@@ -3037,7 +3037,7 @@ static uint32_t const hostSmemSize = configureKernel();
30373037
void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads,
30383038
uint32_t slidingWinSize, float qScale, OutputHead* output,
30393039
#if LOW_PREC_OUTPUT
3040-
float const* rcpOutScale,
3040+
float rcpOutScale,
30413041
#endif
30423042
InputHead const* q, float const* attentionSinks,
30433043
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,

csrc/xqa/xqa_wrapper.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
4545
#else
4646

4747
void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
48-
int64_t slidingWinSize, double qScale, TensorView output,
49-
#if LOW_PREC_OUTPUT
50-
TensorView rcpOutScale,
51-
#endif
48+
int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale,
5249
TensorView q, Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
5350
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
5451
TensorView seqLen, int64_t batchSize, double kvCacheScale,
@@ -70,7 +67,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
7067
mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
7168
reinterpret_cast<OutputHead*>(output.data_ptr()),
7269
#if LOW_PREC_OUTPUT
73-
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
70+
rcpOutScale,
7471
#endif
7572
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
7673
reinterpret_cast<GMemCacheHead*>(kCacheVLLM.data_ptr()),

flashinfer/aot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def gen_xqa(
404404
head_dim=head_size,
405405
head_group_ratio=head_grp_size,
406406
use_sliding_window=use_sliding_window,
407+
output_dtype=input_type,
407408
)
408409

409410
if has_sm120 or has_sm121:

flashinfer/decode.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2077,6 +2077,7 @@ def trtllm_batch_decode_with_kv_cache(
20772077
enable_pdl: Optional[bool] = None,
20782078
backend: str = "auto",
20792079
q_len_per_req: Optional[int] = 1,
2080+
o_scale: Optional[float] = 1.0,
20802081
) -> Union[torch.Tensor, FP4Tensor]:
20812082
"""
20822083
Parameters
@@ -2142,6 +2143,9 @@ def trtllm_batch_decode_with_kv_cache(
21422143
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
21432144
For sm_90 (hopper architecture) and sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
21442145
2146+
o_scale : Optional[float] = 1.0
2147+
output scale factor for xqa fp8 output.
2148+
21452149
Returns
21462150
-------
21472151
out : Union[torch.Tensor, FP4Tensor]
@@ -2196,6 +2200,7 @@ def trtllm_batch_decode_with_kv_cache(
21962200
kv_layout=kv_layout,
21972201
enable_pdl=enable_pdl,
21982202
q_len_per_req=q_len_per_req,
2203+
o_scale=o_scale,
21992204
)
22002205
elif backend == "trtllm-gen":
22012206
# Convert NHD layout to HND if necessary (transpose only changes stride, not data)
@@ -2340,6 +2345,7 @@ def xqa_batch_decode_with_kv_cache(
23402345
kv_layout: str = "NHD",
23412346
enable_pdl: bool = None,
23422347
q_len_per_req: Optional[int] = 1,
2348+
o_scale: Optional[float] = 1.0,
23432349
) -> torch.Tensor:
23442350
"""
23452351
Parameters
@@ -2388,6 +2394,9 @@ def xqa_batch_decode_with_kv_cache(
23882394
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
23892395
Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode.
23902396
2397+
o_scale : Optional[float] = 1.0
2398+
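Output scale factor for fp8 output. Must be a non-zero float: its reciprocal
(``1.0 / o_scale``) is passed to the XQA kernel as ``rcp_out_scale``.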
2399+
23912400
Returns
23922401
-------
23932402
out : torch.Tensor
@@ -2434,7 +2443,7 @@ def xqa_batch_decode_with_kv_cache(
24342443
workspace_u8 = workspace_buffer.view(torch.uint8)
24352444
semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore
24362445
scratch = workspace_u8[8 * 1024 * 1024 :]
2437-
kv_scale_value = bmm2_scale
2446+
kv_scale_value = bmm2_scale * o_scale
24382447
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
24392448

24402449
query_new = query.unsqueeze(1)
@@ -2464,6 +2473,7 @@ def xqa_batch_decode_with_kv_cache(
24642473
kv_layout=kv_layout,
24652474
sm_count=sm_count,
24662475
enable_pdl=enable_pdl,
2476+
rcp_out_scale=1.0 / o_scale,
24672477
)
24682478

24692479
return out

flashinfer/jit/xqa.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"-DBEAM_WIDTH=1",
2929
"-DUSE_INPUT_KV=0",
3030
"-DUSE_CUSTOM_BARRIER=1",
31-
"-DLOW_PREC_OUTPUT=0",
3231
"-DSPEC_DEC=0",
3332
]
3433

@@ -40,6 +39,7 @@ def gen_xqa_module(
4039
head_dim: int,
4140
head_group_ratio: int,
4241
use_sliding_window: bool,
42+
output_dtype: torch.dtype,
4343
) -> JitSpec:
4444
if input_dtype == torch.float16:
4545
flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
@@ -76,6 +76,11 @@ def gen_xqa_module(
7676
else:
7777
flag_sliding_window = ["-DSLIDING_WINDOW=0"]
7878

79+
if output_dtype == torch.float8_e4m3fn:
80+
flag_low_prec_output = ["-DLOW_PREC_OUTPUT=1"]
81+
else:
82+
flag_low_prec_output = ["-DLOW_PREC_OUTPUT=0"]
83+
7984
compilation_context = CompilationContext()
8085
nvcc_flags = compilation_context.get_nvcc_flags_list(
8186
supported_major_versions=[9, 10, 11, 12]
@@ -85,7 +90,7 @@ def gen_xqa_module(
8590
flag_mla_wrapper = ["-DMLA_WRAPPER=0"]
8691

8792
return gen_jit_spec(
88-
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
93+
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
8994
[
9095
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
9196
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
@@ -101,6 +106,7 @@ def gen_xqa_module(
101106
+ flag_kv_cache_dtype
102107
+ flag_head_group_ratio
103108
+ flag_sliding_window
109+
+ flag_low_prec_output
104110
+ flag_mla_wrapper,
105111
extra_ldflags=["-lcuda"], # Add CUDA Driver API library
106112
)

flashinfer/xqa.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def get_xqa_module(
3838
head_dim: int,
3939
head_group_ratio: int,
4040
use_sliding_window: bool,
41+
output_dtype: torch.dtype,
4142
):
4243
module = gen_xqa_module(
4344
input_dtype,
@@ -46,10 +47,11 @@ def get_xqa_module(
4647
head_dim,
4748
head_group_ratio,
4849
use_sliding_window,
50+
output_dtype,
4951
).build_and_load()
5052

5153
@register_custom_op(
52-
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
54+
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
5355
mutates_args=("output", "workspace_buffer"),
5456
)
5557
def xqa(
@@ -59,6 +61,7 @@ def xqa(
5961
sliding_win_size: int,
6062
q_scale: float,
6163
output: torch.Tensor,
64+
rcp_out_scale: float,
6265
q: torch.Tensor,
6366
sinks: Optional[torch.Tensor],
6467
k_cache: torch.Tensor,
@@ -79,6 +82,7 @@ def xqa(
7982
sliding_win_size,
8083
q_scale,
8184
output,
85+
rcp_out_scale,
8286
q,
8387
sinks,
8488
k_cache,
@@ -94,7 +98,7 @@ def xqa(
9498
)
9599

96100
@register_fake_op(
97-
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
101+
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
98102
)
99103
def _fake_xqa(
100104
run_sm90_fp8_mha: bool,
@@ -103,6 +107,7 @@ def _fake_xqa(
103107
sliding_win_size: int,
104108
q_scale: float,
105109
output: torch.Tensor,
110+
rcp_out_scale: float,
106111
q: torch.Tensor,
107112
sinks: Optional[torch.Tensor],
108113
k_cache: torch.Tensor,
@@ -140,6 +145,7 @@ def xqa(
140145
kv_layout: str = "NHD",
141146
sm_count: Optional[int] = None,
142147
enable_pdl: Optional[bool] = None,
148+
rcp_out_scale: float = 1.0,
143149
) -> None:
144150
r"""Apply attention with paged KV cache using XQA kernel.
145151
Parameters
@@ -167,7 +173,7 @@ def xqa(
167173
Data type should be torch.uint32.
168174
output : torch.Tensor
169175
Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
170-
Data type should match query tensor. This tensor will be modified in-place.
176+
Data type should match the query tensor, or be ``torch.float8_e4m3fn`` when the KV cache is also fp8. This tensor will be modified in-place.
171177
workspace_buffer : torch.Tensor
172178
Workspace buffer for temporary computations.
173179
Data type should be torch.uint8.
@@ -196,6 +202,8 @@ def xqa(
196202
enable_pdl : Optional[bool], default=None
197203
Whether to enable PDL (Programmatic Dependent Launch) optimization.
198204
If None, will be set to True if hardware supports it.
205+
rcp_out_scale : float, default=1.0
206+
Reciprocal of the output scale factor. Only takes effect when the output is fp8
(kernels built with ``LOW_PREC_OUTPUT=1``).
199207
200208
Note
201209
----
@@ -231,6 +239,13 @@ def xqa(
231239

232240
assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype"
233241

242+
if output.dtype == torch.float8_e4m3fn:
243+
assert k_cache.dtype == torch.float8_e4m3fn, (
244+
"KV cache must be fp8 when output is fp8"
245+
)
246+
else:
247+
assert output.dtype == q.dtype, "Output and query must have the same dtype"
248+
234249
# Convert HND layout to NHD if necessary (transpose only changes stride, not data)
235250
if kv_layout == "HND":
236251
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
@@ -255,6 +270,7 @@ def xqa(
255270
head_dim,
256271
head_group_ratio,
257272
use_sliding_window,
273+
output.dtype,
258274
)
259275
xqa_module.xqa(
260276
run_sm90_fp8_mha,
@@ -263,6 +279,7 @@ def xqa(
263279
sliding_win_size if use_sliding_window else 0,
264280
q_scale,
265281
output,
282+
rcp_out_scale,
266283
q,
267284
sinks,
268285
k_cache,

0 commit comments

Comments
 (0)