diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index 603ce5ecf0d2..6ddab4621457 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -259,6 +259,7 @@ def write_results_to_csv(results, filename=None):
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
     (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
index 40903c6c3444..131df74c7de1 100644
--- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
@@ -274,6 +274,7 @@ def write_results_to_csv(results, filename=None):
 quant_dtypes = [
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 8d0a11d8eb8a..bd3ba554b32e 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -35,6 +35,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
     (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
@@ -44,6 +45,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 HEAD_SIZE = [128]
 KV_LAYOUT = ["HND"]  # currently only HND is supported
 BLOCK_SIZE = [16]
+WINDOW_LEFT = [-1, 127]
 SOFT_CAP = [None, 50.0]
 
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
@@ -57,6 +59,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 @pytest.mark.parametrize("head_size", HEAD_SIZE)
 @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
+@pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", SOFT_CAP)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     head_size: int,
     kv_layout: str,
     block_size: int,
+    window_left: int,
     soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         sm_scale=sm_scale,
         q_data_type=dtype,
         kv_data_type=dtype,
+        window_left=window_left,
         logits_soft_cap=soft_cap)
 
     output = torch.empty(ref_query.shape, dtype=dtype)
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         max_seq_len=max_seq_len,
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
+        window_left=window_left,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @pytest.mark.parametrize("head_size", HEAD_SIZE)
 @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
+@pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", [None])
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     head_size: int,
     kv_layout: str,
     block_size: int,
+    window_left: int,
     soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         sm_scale=sm_scale,
         q_data_type=dtype,
         kv_data_type=dtype,
+        window_left=window_left,
         logits_soft_cap=soft_cap)
 
     output = torch.empty(ref_query.shape, dtype=dtype)
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         batch_size=batch_size,
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
+        window_left=window_left,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         rtol, atol = 4e-1, 1e0
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
         rtol, atol = 5e-2, 7e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2
 
diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py
index 3095f17110fd..43c345695ef4 100644
--- a/vllm/compilation/fusion_attn.py
+++ b/vllm/compilation/fusion_attn.py
@@ -258,8 +258,10 @@ def __init__(self, config: VllmConfig):
             pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
             pattern_fp8.register_if_supported(self.patterns)
 
-            pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
-            pattern_nvfp4.register_if_supported(self.patterns)
+            if current_platform.is_cuda() and hasattr(torch.ops._C,
+                                                      "scaled_fp4_quant"):
+                pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
+                pattern_nvfp4.register_if_supported(self.patterns)
 
         if len(attn_layers) == 0:
             logger.warning(
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 06a853007a57..c7a565810b45 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -194,19 +194,15 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         FlashInferBackend.validate_head_size(self.head_dim)
         self.page_size = self.kv_cache_spec.block_size
 
-        self.enable_fusion = (
-            self.compilation_config.pass_config.enable_attn_fusion)
-        self.q_data_type = self.model_config.dtype
         self.cache_dtype = self.cache_config.cache_dtype
         if self.cache_dtype.startswith("fp8"):
             self.kv_cache_dtype = (
                 FlashInferBackend.get_fp8_dtype_for_flashinfer(
                     self.cache_dtype))
-            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
-            if self.enable_fusion:
-                self.q_data_type = self.kv_cache_dtype
         else:
+            assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
+        self.q_data_type = self.kv_cache_dtype
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
@@ -668,8 +664,6 @@ def forward(
 
         # The attn+quant fusion happens when output_scale is provided.
         if output_scale is None:
-            assert attn_metadata.q_data_type != FP8_DTYPE, \
-                "Query can only be FP8 if output fusion happened."
             assert output_block_scale is None, "output_block_scale "\
                 "is not supported when fusion has not happened"
         else:
@@ -697,7 +691,8 @@ def forward(
             elif output.dtype == FP4_DTYPE:
                 self.o_sf_scale = layer._o_scale_float
 
-            # Insert FP8 quant for query
+        # Insert FP8 quant for query
+        if attn_metadata.q_data_type == FP8_DTYPE:
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
                 query.reshape(