@@ -38,6 +38,7 @@ def get_xqa_module(
3838 head_dim : int ,
3939 head_group_ratio : int ,
4040 use_sliding_window : bool ,
41+ output_dtype : torch .dtype ,
4142):
4243 module = gen_xqa_module (
4344 input_dtype ,
@@ -46,10 +47,11 @@ def get_xqa_module(
4647 head_dim ,
4748 head_group_ratio ,
4849 use_sliding_window ,
50+ output_dtype ,
4951 ).build_and_load ()
5052
5153 @register_custom_op (
52- f"flashinfer::xqa_input_{ filename_safe_dtype_map [input_dtype ]} _kv_cache_{ filename_safe_dtype_map [kv_cache_dtype ]} _page_size_{ page_size } _head_dim_{ head_dim } _head_group_ratio_{ head_group_ratio } _use_sliding_window_{ use_sliding_window } " ,
54+ f"flashinfer::xqa_input_{ filename_safe_dtype_map [input_dtype ]} _kv_cache_{ filename_safe_dtype_map [kv_cache_dtype ]} _output_ { filename_safe_dtype_map [ output_dtype ] } _page_size_{ page_size } _head_dim_{ head_dim } _head_group_ratio_{ head_group_ratio } _use_sliding_window_{ use_sliding_window } " ,
5355 mutates_args = ("output" , "workspace_buffer" ),
5456 )
5557 def xqa (
@@ -59,6 +61,7 @@ def xqa(
5961 sliding_win_size : int ,
6062 q_scale : float ,
6163 output : torch .Tensor ,
64+ rcp_out_scale : float ,
6265 q : torch .Tensor ,
6366 sinks : Optional [torch .Tensor ],
6467 k_cache : torch .Tensor ,
@@ -79,6 +82,7 @@ def xqa(
7982 sliding_win_size ,
8083 q_scale ,
8184 output ,
85+ rcp_out_scale ,
8286 q ,
8387 sinks ,
8488 k_cache ,
@@ -94,7 +98,7 @@ def xqa(
9498 )
9599
96100 @register_fake_op (
97- f"flashinfer::xqa_input_{ filename_safe_dtype_map [input_dtype ]} _kv_cache_{ filename_safe_dtype_map [kv_cache_dtype ]} _page_size_{ page_size } _head_dim_{ head_dim } _head_group_ratio_{ head_group_ratio } _use_sliding_window_{ use_sliding_window } "
101+ f"flashinfer::xqa_input_{ filename_safe_dtype_map [input_dtype ]} _kv_cache_{ filename_safe_dtype_map [kv_cache_dtype ]} _output_ { filename_safe_dtype_map [ output_dtype ] } _page_size_{ page_size } _head_dim_{ head_dim } _head_group_ratio_{ head_group_ratio } _use_sliding_window_{ use_sliding_window } "
98102 )
99103 def _fake_xqa (
100104 run_sm90_fp8_mha : bool ,
@@ -103,6 +107,7 @@ def _fake_xqa(
103107 sliding_win_size : int ,
104108 q_scale : float ,
105109 output : torch .Tensor ,
110+ rcp_out_scale : float ,
106111 q : torch .Tensor ,
107112 sinks : Optional [torch .Tensor ],
108113 k_cache : torch .Tensor ,
@@ -140,6 +145,7 @@ def xqa(
140145 kv_layout : str = "NHD" ,
141146 sm_count : Optional [int ] = None ,
142147 enable_pdl : Optional [bool ] = None ,
148+ rcp_out_scale : float = 1.0 ,
143149) -> None :
144150 r"""Apply attention with paged KV cache using XQA kernel.
145151 Parameters
@@ -167,7 +173,7 @@ def xqa(
167173 Data type should be torch.uint32.
168174 output : torch.Tensor
169175 Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
170- Data type should match query tensor. This tensor will be modified in-place.
176+ Data type should match the query tensor, or the KV cache tensor when the output is fp8. This tensor will be modified in-place.
171177 workspace_buffer : torch.Tensor
172178 Workspace buffer for temporary computations.
173179 Data type should be torch.uint8.
@@ -196,6 +202,8 @@ def xqa(
196202 enable_pdl : Optional[bool], default=None
197203 Whether to enable PDL (Persistent Data Loader) optimization.
198204 If None, will be set to True if hardware supports it.
205+ rcp_out_scale : float, default=1.0
206+ Reciprocal of the output scale factor (i.e., 1 / out_scale).
199207
200208 Note
201209 ----
@@ -231,6 +239,13 @@ def xqa(
231239
232240 assert k_cache .dtype == v_cache .dtype , "K and V cache must have the same dtype"
233241
242+ if output .dtype == torch .float8_e4m3fn :
243+ assert k_cache .dtype == torch .float8_e4m3fn , (
244+ "KV cache must be fp8 when output is fp8"
245+ )
246+ else :
247+ assert output .dtype == q .dtype , "Output and query must have the same dtype"
248+
234249 # Convert HND layout to NHD if necessary (transpose only changes stride, not data)
235250 if kv_layout == "HND" :
236251 # For HND: [..., H, N, D] -> NHD: [..., N, H, D]
@@ -255,6 +270,7 @@ def xqa(
255270 head_dim ,
256271 head_group_ratio ,
257272 use_sliding_window ,
273+ output .dtype ,
258274 )
259275 xqa_module .xqa (
260276 run_sm90_fp8_mha ,
@@ -263,6 +279,7 @@ def xqa(
263279 sliding_win_size if use_sliding_window else 0 ,
264280 q_scale ,
265281 output ,
282+ rcp_out_scale ,
266283 q ,
267284 sinks ,
268285 k_cache ,
0 commit comments