
Commit eccbdde

minor: canonicalize TFLOPS calculation (#2069)
## 📌 Description

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Refactor**
  * TFLOPS computation standardized across attention benchmarks so reported performance metrics consistently account for actual sequence and batch lengths.
* **Bug Fixes**
  * Added checks to prevent invalid mixed-length causal inputs, avoiding misleading benchmark results.
* **Chores**
  * Renamed timing parameter in the benchmark utility for clearer intent.
1 parent 11177e8 commit eccbdde

File tree

5 files changed: +87 −39 lines changed

- benchmarks/bench_blackwell_attention.py
- benchmarks/bench_block_sparse_attention.py
- benchmarks/bench_hopper_attention.py
- benchmarks/bench_hopper_fp8_attention.py
- flashinfer/testing/utils.py
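
Background for the diffs below: each benchmark previously computed throughput with an inline formula such as `batch_size * seq_len * seq_len * num_heads * head_dim * 4 / ms / 1e9`, and this commit routes them all through `attention_tflops_per_sec_with_actual_seq_lens` instead. Both follow the usual two-GEMM accounting for attention (one GEMM for QKᵀ, one for PV, two FLOPs per multiply-accumulate). The expression below is a rough sketch of that accounting, not the helper's exact implementation, with $B$ the batch size, $H$ the number of query heads, $q_i$ and $k_i$ the per-sequence query and key/value lengths, $d_{qk}$ and $d_{vo}$ the head dimensions, and $\text{ms}$ the measured time in milliseconds:

$$
\text{TFLOPs/s} \;\approx\; \frac{2\,H\sum_{i=1}^{B} q_i\,k_i\,(d_{qk}+d_{vo})}{\text{ms}\times 10^{9}},
$$

with the numerator roughly halved when a causal mask is applied. For a uniform batch (every $q_i = k_i = L$ and $d_{qk} = d_{vo} = d$) this reduces to the old inline expression, so for these benchmarks the change canonicalizes how the number is computed rather than what it measures.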

benchmarks/bench_blackwell_attention.py

Lines changed: 14 additions & 8 deletions
@@ -18,7 +18,10 @@
 import torch
 
 import flashinfer
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing.utils import (
+    bench_gpu_time,
+    attention_tflops_per_sec_with_actual_seq_lens,
+)
 
 
 def bench_fmha_blackwell(
@@ -69,14 +72,17 @@ def bench_fmha_blackwell(
     )
     ms = np.median(measurements)
 
-    def flops(ms):
-        if causal:
-            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
-        else:
-            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9
-
+    TFLOPS = attention_tflops_per_sec_with_actual_seq_lens(
+        torch.full((batch_size,), qkv_len),
+        torch.full((batch_size,), qkv_len),
+        head_dim,
+        head_dim,
+        num_heads,
+        causal,
+        ms,
+    )
     print(
-        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s"
+        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s"
     )
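
As a sanity check on this substitution, the old inline formula and the canonical helper should report the same number whenever every sequence in the batch has the same length. A minimal sketch, using made-up benchmark parameters and the positional argument order shown in the diff above:

```python
import torch

from flashinfer.testing.utils import attention_tflops_per_sec_with_actual_seq_lens

# Hypothetical benchmark parameters, chosen only for illustration.
batch_size, qkv_len, num_heads, head_dim = 16, 4096, 32, 128
ms = 1.0  # measured kernel time in milliseconds

# Old inline accounting used by this benchmark (non-causal case).
old_tflops = batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9

# Canonical helper, called the same way as in the diff: uniform per-sequence
# lengths for q and kv, identical qk/vo head dims, non-causal.
new_tflops = attention_tflops_per_sec_with_actual_seq_lens(
    torch.full((batch_size,), qkv_len),  # actual query lengths
    torch.full((batch_size,), qkv_len),  # actual key/value lengths
    head_dim,                            # qk head dim
    head_dim,                            # vo head dim
    num_heads,                           # number of query heads
    False,                               # causal
    ms,                                  # execution time in milliseconds
)

# If the helper follows the two-GEMM accounting sketched earlier, the two
# values agree for uniform lengths; with causal=True the helper applies the
# masked (roughly halved) FLOP count instead.
print(f"old: {old_tflops:.3f} TFLOPs/s, canonical: {new_tflops:.3f} TFLOPs/s")
```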

benchmarks/bench_block_sparse_attention.py

Lines changed: 13 additions & 2 deletions
@@ -18,7 +18,10 @@
 import torch
 
 import flashinfer
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing.utils import (
+    bench_gpu_time,
+    attention_tflops_per_sec_with_actual_seq_lens,
+)
 
 
 def bench_variable_block_sparse_attention(
@@ -120,7 +123,15 @@ def bench_variable_block_sparse_attention(
     )
 
     def flops(ms):
-        return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.tensor([seq_len]),
+            torch.tensor([seq_len]),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            False,
+            ms,
+        )
 
     print(
         f"bench_variable_block_sparse_attention (num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, seq_len={seq_len}, num_blocks_row={num_blocks_row}, num_blocks_col={num_blocks_col}, block_density={block_density}), sparse fa2-template: {flops(sparse_ms_fa2):.3f} TFLOPs/s, sparse fa3-template: {flops(sparse_ms_fa3):.3f} TFLOPs/s, dense fa2-template: {flops(dense_sm80_ms):.3f} TFLOPs/s, dense fa3-template: {flops(dense_sm90_ms):.3f} TFLOPs/s"

benchmarks/bench_hopper_attention.py

Lines changed: 31 additions & 21 deletions
@@ -18,7 +18,10 @@
 import torch
 
 import flashinfer
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing.utils import (
+    bench_gpu_time,
+    attention_tflops_per_sec_with_actual_seq_lens,
+)
 
 
 def bench_single_prefill(seq_len, num_heads, causal, head_dim):
@@ -41,10 +44,15 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim):
     )
 
     def flops(ms):
-        if causal:
-            return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
-        else:
-            return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.tensor([seq_len]),
+            torch.tensor([seq_len]),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
 
     print(
         f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s"
@@ -97,14 +105,15 @@ def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim)
     )
 
     def flops(ms):
-        if causal:
-            return (
-                batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
-            )
-        else:
-            return (
-                batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9
-            )
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.full((batch_size,), seq_len),
+            torch.full((batch_size,), seq_len),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
 
     print(
         f"bench_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s"
@@ -176,14 +185,15 @@ def bench_batch_paged_prefill(
     )
 
     def flops(ms):
-        if causal:
-            return (
-                batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
-            )
-        else:
-            return (
-                batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9
-            )
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.full((batch_size,), seq_len),
+            torch.full((batch_size,), seq_len),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
 
     print(
         f"bench_batch_paged_prefill (page_size={page_size} batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s"

benchmarks/bench_hopper_fp8_attention.py

Lines changed: 13 additions & 5 deletions
@@ -2,7 +2,10 @@
 import torch
 
 import flashinfer
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing.utils import (
+    bench_gpu_time,
+    attention_tflops_per_sec_with_actual_seq_lens,
+)
 
 
 def bench_single_prefill(seq_len, num_heads, causal, head_dim):
@@ -45,10 +48,15 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim):
     )
 
     def flops(ms):
-        if causal:
-            return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
-        else:
-            return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.tensor([seq_len]),
+            torch.tensor([seq_len]),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
 
     print(
         f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"

flashinfer/testing/utils.py

Lines changed: 16 additions & 3 deletions
@@ -277,6 +277,12 @@ def attention_flops(
     Returns:
         total_flops (int): Total FLOPs for the layer.
     """
+    # Causal attention requires kv_len >= q_len
+    if qo_seqlen > kv_seqlen:
+        raise ValueError(
+            "qo_seqlen must be less than or equal to kv_seqlen for causal attention"
+        )
+
     if causal:
         bmm1_flops = (
             batch_size
@@ -323,6 +329,13 @@ def attention_flops_with_actual_seq_lens(
     Returns:
         total_flops (int): Total FLOPs for the layer.
     """
+    # Causal attention requires kv_len >= q_len
+    # Otherwise right align if kv_len > q_len
+    if causal and (actual_seq_lens_q > actual_seq_lens_kv).any():
+        raise ValueError(
+            "actual_seq_lens_q must be less than or equal to actual_seq_lens_kv for causal attention"
+        )
+
     if causal:
         bmm1_flops = (
             torch.dot(
@@ -412,7 +425,7 @@ def attention_tflops_per_sec_with_actual_seq_lens(
     head_dim_vo,
     num_qo_heads,
     causal,
-    time,
+    ms,
 ):
     """
     Calculate TFLOPS per second for a given attention layer with actual sequence lengths.
@@ -425,7 +438,7 @@ def attention_tflops_per_sec_with_actual_seq_lens(
         head_dim_vo (int): Head dimension of the value.
         num_qo_heads (int): Number of query heads.
         causal (bool): Whether to use causal masking.
-        time (float): Execution time in milliseconds.
+        ms (float): Execution time in milliseconds.
 
     Returns:
         tflops_per_sec (float): TFLOPS per second for the layer.
@@ -438,7 +451,7 @@ def attention_tflops_per_sec_with_actual_seq_lens(
         num_qo_heads,
         causal,
     )
-    return f.item() / time / 1e9 if not math.isnan(time) else 0.0
+    return f.item() / ms / 1e9 if not math.isnan(ms) else 0.0
 
 
 def attention_tb_per_sec(
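
The new validation in `attention_flops_with_actual_seq_lens` rejects causal inputs where a query sequence is longer than its key/value sequence, and the TFLOPS wrapper computes its FLOP count through that function. A sketch of what that looks like in use, with hypothetical lengths and assuming the check propagates through the wrapper as the diff suggests:

```python
import torch

from flashinfer.testing.utils import attention_tflops_per_sec_with_actual_seq_lens

# Hypothetical mixed-length batch: the second query sequence (1024) is longer
# than its kv sequence (768), which is invalid under causal masking.
actual_seq_lens_q = torch.tensor([512, 1024])
actual_seq_lens_kv = torch.tensor([512, 768])

try:
    attention_tflops_per_sec_with_actual_seq_lens(
        actual_seq_lens_q,
        actual_seq_lens_kv,
        128,   # qk head dim
        128,   # vo head dim
        32,    # number of query heads
        True,  # causal
        1.0,   # execution time in milliseconds
    )
except ValueError as err:
    # Expected: the q lengths must not exceed the kv lengths under causal masking,
    # so the call raises instead of reporting a misleading TFLOPS number.
    print(err)
```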
