
Commit 4fd07a2

Improve fmha fp8 perf (#1555)
* Separate dqk and dv
* update ck
* update ck
* Simplify the default argument
* Add perf test
1 parent 2ae0991 commit 4fd07a2

File tree

3 files changed: +137 −32 lines


3rdparty/composable_kernel

Submodule composable_kernel updated 549 files

op_tests/test_mha_fp8.py

Lines changed: 66 additions & 13 deletions
@@ -4,10 +4,14 @@
 import torch
 import aiter
 from aiter import dtypes
+from aiter.test_common import run_perftest
 from aiter import per_tensor_quant
 import pytest
+import pandas as pd
 import argparse

+benchmark = {}
+

 def run_ck(
     q,
@@ -20,7 +24,8 @@ def run_ck(
     v_descale=None,
 ):
     if q.dtype == dtypes.fp8 and k.dtype == dtypes.fp8 and v.dtype == dtypes.fp8:
-        return aiter.flash_attn_fp8_pertensor_func(
+        return run_perftest(
+            aiter.flash_attn_fp8_pertensor_func,
             q,
             k,
             v,
@@ -31,7 +36,8 @@ def run_ck(
             window_size=window_size,
         )
     else:
-        return aiter.flash_attn_func(
+        return run_perftest(
+            aiter.flash_attn_func,
             q,
             k,
             v,
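Both branches of run_ck now go through run_perftest from aiter.test_common: it calls the wrapped attention function and returns the output together with the measured latency in microseconds, which the call sites below unpack as an (out, us) pair. As a rough mental model only, a minimal stand-in with the same shape might look like this (the real helper's warmup and iteration-averaging behavior is not shown in this diff):

import time

def run_perftest_sketch(fn, *args, **kwargs):
    # Hypothetical stand-in for aiter.test_common.run_perftest: time one call
    # and return (output, elapsed_microseconds). The real helper presumably
    # adds warmup runs and averages over many iterations.
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    t1 = time.perf_counter()
    return out, (t1 - t0) * 1.0e6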
@@ -112,7 +118,7 @@ def test_flash_attn_output(
     k_quant, k_descale = per_tensor_quant(k, quant_dtype=quant_dtype)
     v_quant, v_descale = per_tensor_quant(v, quant_dtype=quant_dtype)

-    out = run_ck(
+    out, us_quant_fwd = run_ck(
         q_quant,
         k_quant,
         v_quant,
@@ -122,12 +128,41 @@ def test_flash_attn_output(
         k_descale,
         v_descale,
     )
-    out_ref = run_ck(q, k, v, causal, window_size)
+    out_ref, us_fwd = run_ck(q, k, v, causal, window_size)

     max_diff = (out - out_ref).abs().max().item()
     print(f"Output max diff: {max_diff}")
     assert max_diff < 0.055

+    fwd_flop = (
+        batch_size
+        * nheads
+        * (seqlen_q * seqlen_k * d * 2 + seqlen_q * seqlen_k * d_v * 2)
+    )
+
+    dtype_bytes = torch.finfo(dtype).bits // 8
+    quant_dtype_bytes = torch.finfo(quant_dtype).bits // 8
+
+    fwd_num_bytes = (
+        batch_size
+        * nheads
+        * dtype_bytes
+        * (seqlen_q * d + seqlen_k * d + seqlen_k * d_v + seqlen_q * d_v)
+    )
+    quant_fwd_num_bytes = (
+        batch_size
+        * nheads
+        * quant_dtype_bytes
+        * (seqlen_q * d + seqlen_k * d + seqlen_k * d_v + seqlen_q * d_v)
+    )
+
+    benchmark["quant_fwd_us"] = us_quant_fwd
+    benchmark["quant_fwd_tflops"] = fwd_flop / 1.0e6 / us_quant_fwd
+    benchmark["quant_fwd_gb_per_sec"] = quant_fwd_num_bytes / 1.0e3 / us_quant_fwd
+    benchmark["fwd_us"] = us_fwd
+    benchmark["fwd_tflops"] = fwd_flop / 1.0e6 / us_fwd
+    benchmark["fwd_gb_per_sec"] = fwd_num_bytes / 1.0e3 / us_fwd
+


 parser = argparse.ArgumentParser(
     formatter_class=argparse.RawTextHelpFormatter,
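The FLOP model above is the usual dense-attention forward count: per batch element and head, Q·K^T costs 2 * seqlen_q * seqlen_k * d FLOPs and P·V costs 2 * seqlen_q * seqlen_k * d_v, which is exactly the two summed terms (causal masking would roughly halve the real work, but it is not factored in here). The conversions work because the timings are in microseconds: flop / (us * 1e-6 s) / 1e12 = flop / us / 1e6 gives TFLOPS, and bytes / (us * 1e-6 s) / 1e9 = bytes / us / 1e3 gives GB/s. A worked check with hypothetical shapes and an assumed, not measured, latency:

# Hypothetical shapes and timing, purely illustrative.
batch_size, nheads = 2, 5
seqlen_q = seqlen_k = 512
d = d_v = 128
fwd_flop = batch_size * nheads * (
    seqlen_q * seqlen_k * d * 2 + seqlen_q * seqlen_k * d_v * 2
)
us = 100.0  # assumed kernel time in microseconds
print(fwd_flop)               # 1342177280, ~1.34 GFLOP
print(fwd_flop / 1.0e6 / us)  # ~13.4 TFLOPS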
@@ -153,8 +188,8 @@ def test_flash_attn_output(
     "-nk",
     "--nheads_k",
     type=int,
-    default=5,
-    help="""Number of heads. Default is 5.
+    default=-1,
+    help="""Number of heads. -1 means equal to n (nheads).
     e.g.: -nk 1""",
 )
 parser.add_argument(
@@ -169,18 +204,26 @@ def test_flash_attn_output(
     "-k",
     "--seqlen_k",
     type=int,
-    default=512,
-    help="""Sequence length for key. Default is 512.
+    default=-1,
+    help="""Sequence length for key. -1 means equal to q (seqlen_q).
     e.g.: -k 1024""",
 )
 parser.add_argument(
     "-d",
-    "--d_qkv",
+    "--d_qk",
     type=int,
     default=128,
     help="""Dimension of query and key. Default is 128.
     e.g.: -d 128""",
 )
+parser.add_argument(
+    "-dv",
+    "--d_v",
+    type=int,
+    default=-1,
+    help="""Dimension of value. -1 means equal to d (d_qk).
+    e.g.: -dv 128""",
+)
 parser.add_argument(
     "-c",
     "--causal",
@@ -198,14 +241,24 @@ def test_flash_attn_output(

 if __name__ == "__main__":
     args = parser.parse_args()
+
+    nheads_k = args.nheads_k if args.nheads_k > 0 else args.nheads
+    seqlen_k = args.seqlen_k if args.seqlen_k > 0 else args.seqlen_q
+    d_v = args.d_v if args.d_v > 0 else args.d_qk
+
+    collected = []
     test_flash_attn_output(
         args.batch_size,
         args.nheads,
-        args.nheads_k,
+        nheads_k,
         args.seqlen_q,
-        args.seqlen_k,
-        args.d_qkv,
-        args.d_qkv,
+        seqlen_k,
+        args.d_qk,
+        d_v,
         args.causal,
         args.local,
     )
+    collected.append(benchmark)
+
+    df = pd.DataFrame(collected)
+    aiter.logger.info(f"mha summary:\n{df}")
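With the new __main__ logic, the -1 sentinels fall back to their companions (nheads_k to nheads, seqlen_k to seqlen_q, d_v to d_qk), the test runs once, and the module-level benchmark dict is logged as a one-row pandas DataFrame. A sketch of what the summary holds, with placeholder numbers rather than measurements:

import pandas as pd

# Placeholder values only; the real numbers come from run_perftest.
collected = [{
    "quant_fwd_us": 50.0,
    "quant_fwd_tflops": 26.8,
    "quant_fwd_gb_per_sec": 52.4,
    "fwd_us": 100.0,
    "fwd_tflops": 13.4,
    "fwd_gb_per_sec": 52.4,
}]
df = pd.DataFrame(collected)
print(df)  # one row per invocation, one column per metric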

op_tests/test_mha_varlen_fp8.py

Lines changed: 70 additions & 18 deletions
@@ -4,19 +4,18 @@
 import torch
 import aiter
 from aiter import dtypes
+from aiter.test_common import run_perftest
 from aiter import per_tensor_quant
 from aiter.test_mha_common import (
-    attention_ref,
-    attn_bias_from_alibi_slopes,
-    ck_randval_to_dropout_mask,
-    convert_flash_attn_S_to_softmax,
     generate_qkv,
     generate_random_padding_mask,
-    pad_rearrange_dropout_mask_hts_to_bhss,
 )
 import pytest
+import pandas as pd
 import argparse

+benchmark = {}
+

 def run_ck(
     q,
@@ -34,7 +33,8 @@ def run_ck(
     v_descale=None,
 ):
     if q.dtype == dtypes.fp8 and k.dtype == dtypes.fp8 and v.dtype == dtypes.fp8:
-        return aiter.flash_attn_varlen_fp8_pertensor_func(
+        return run_perftest(
+            aiter.flash_attn_varlen_fp8_pertensor_func,
             q,
             k,
             v,
@@ -50,7 +50,8 @@ def run_ck(
             window_size=window_size,
         )
     else:
-        return aiter.flash_attn_varlen_func(
+        return run_perftest(
+            aiter.flash_attn_varlen_func,
             q,
             k,
             v,
@@ -167,7 +168,7 @@ def test_flash_attn_varlen_output(
     k_quant, k_descale = per_tensor_quant(k, quant_dtype=quant_dtype)
     v_quant, v_descale = per_tensor_quant(v, quant_dtype=quant_dtype)

-    out = run_ck(
+    out, us_quant_fwd = run_ck(
         q_quant,
         k_quant,
         v_quant,
@@ -183,7 +184,7 @@ def test_flash_attn_varlen_output(
         v_descale,
     )

-    out_ref = run_ck(
+    out_ref, us_fwd = run_ck(
         q,
         k,
         v,
@@ -200,6 +201,39 @@ def test_flash_attn_varlen_output(
     print(f"Output max diff: {max_diff}")
     assert max_diff < 0.055

+    fwd_flop = 0
+    dtype_bytes = torch.finfo(dtype).bits // 8
+    quant_dtype_bytes = torch.finfo(quant_dtype).bits // 8
+    fwd_num_bytes = 0
+    quant_fwd_num_bytes = 0
+    for i in range(len(cu_seqlens_q) - 1):
+        real_seqlen_q = cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item()
+        real_seqlen_k = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
+        fwd_flop = (
+            fwd_flop
+            + nheads * 2 * real_seqlen_q * real_seqlen_k * d
+            + nheads * 2 * real_seqlen_q * real_seqlen_k * d_v
+        )
+        fwd_num_bytes = fwd_num_bytes + nheads * dtype_bytes * (
+            real_seqlen_q * d
+            + real_seqlen_k * d
+            + real_seqlen_k * d_v
+            + real_seqlen_q * d_v
+        )
+        quant_fwd_num_bytes = quant_fwd_num_bytes + nheads * quant_dtype_bytes * (
+            real_seqlen_q * d
+            + real_seqlen_k * d
+            + real_seqlen_k * d_v
+            + real_seqlen_q * d_v
+        )
+
+    benchmark["quant_fwd_us"] = us_quant_fwd
+    benchmark["quant_fwd_tflops"] = fwd_flop / 1.0e6 / us_quant_fwd
+    benchmark["quant_fwd_gb_per_sec"] = quant_fwd_num_bytes / 1.0e3 / us_quant_fwd
+    benchmark["fwd_us"] = us_fwd
+    benchmark["fwd_tflops"] = fwd_flop / 1.0e6 / us_fwd
+    benchmark["fwd_gb_per_sec"] = fwd_num_bytes / 1.0e3 / us_fwd
+


 parser = argparse.ArgumentParser(
     formatter_class=argparse.RawTextHelpFormatter,
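The varlen variant cannot use a single rectangular formula, so it derives each sequence's true length from the cumulative offset tensors and accumulates FLOPs and bytes per packed sequence. A minimal illustration of the cu_seqlens indexing with hypothetical offsets:

import torch

# cu_seqlens holds cumulative token offsets for a packed batch; consecutive
# differences recover each sequence's real length.
cu_seqlens_q = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
lengths = [
    cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item()
    for i in range(len(cu_seqlens_q) - 1)
]
print(lengths)  # [3, 4, 5]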
@@ -225,8 +259,8 @@ def test_flash_attn_varlen_output(
     "-nk",
     "--nheads_k",
     type=int,
-    default=5,
-    help="""Number of heads. Default is 5.
+    default=-1,
+    help="""Number of heads. -1 means equal to n (nheads).
     e.g.: -nk 1""",
 )
 parser.add_argument(
@@ -241,18 +275,26 @@ def test_flash_attn_varlen_output(
     "-k",
     "--seqlen_k",
     type=int,
-    default=512,
-    help="""Sequence length for key. Default is 512.
+    default=-1,
+    help="""Sequence length for key. -1 means equal to q (seqlen_q).
     e.g.: -k 1024""",
 )
 parser.add_argument(
     "-d",
-    "--d_qkv",
+    "--d_qk",
     type=int,
     default=128,
     help="""Dimension of query and key. Default is 128.
     e.g.: -d 128""",
 )
+parser.add_argument(
+    "-dv",
+    "--d_v",
+    type=int,
+    default=-1,
+    help="""Dimension of value. -1 means equal to d (d_qk).
+    e.g.: -dv 128""",
+)
 parser.add_argument(
     "-msq",
     "--min_seqlen_q",
@@ -279,15 +321,25 @@ def test_flash_attn_varlen_output(

 if __name__ == "__main__":
     args = parser.parse_args()
+
+    nheads_k = args.nheads_k if args.nheads_k > 0 else args.nheads
+    seqlen_k = args.seqlen_k if args.seqlen_k > 0 else args.seqlen_q
+    d_v = args.d_v if args.d_v > 0 else args.d_qk
+
+    collected = []
     test_flash_attn_varlen_output(
         args.batch_size,
         args.nheads,
-        args.nheads_k,
+        nheads_k,
         args.seqlen_q,
-        args.seqlen_k,
-        args.d_qkv,
-        args.d_qkv,
+        seqlen_k,
+        args.d_qk,
+        d_v,
         args.min_seqlen_q,
         args.causal,
         args.local,
     )
+    collected.append(benchmark)
+
+    df = pd.DataFrame(collected)
+    aiter.logger.info(f"mha summary:\n{df}")
