|
1 | 1 | --- aiter/jit/optCompilerConfig.json |
2 | 2 | +++ aiter/jit/optCompilerConfig.json |
3 | | -@@ -619,6 +619,7 @@ |
4 | | - "verbose": "False", |
| 3 | +@@ -699,7 +699,7 @@ |
5 | 4 | "hip_clang_path": "os.environ.get('MHA_HIP_CLANG_PATH')", |
6 | 5 | "blob_gen_cmd": [ |
7 | | -+ "f'{get_asm_dir()}/fmha_v3_fwd/codegen.py --output_dir {{}}'", |
8 | 6 | "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd --receipt 600 --output_dir {{}}'", |
9 | | - "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}'" |
| 7 | +- "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}'" |
| 8 | ++ "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 2 --output_dir {{}}'" |
10 | 9 | ] |
11 | | - |
12 | | ---- csrc/cpp_itfs/mha_fwd_generate.py |
13 | | -+++ csrc/cpp_itfs/mha_fwd_generate.py |
14 | | -@@ -150,7 +150,7 @@ COMBINED_API = """t = fmha_fwd_v3(traits, args, stream_config); |
15 | | - API_MAP = { |
16 | | - 1: FMHA_FWD_API.format(F_inner_dispatch=V3_API), |
17 | | - 2: FMHA_FWD_API.format(F_inner_dispatch=V2_API), |
18 | | -- 3: FMHA_FWD_API.format(F_inner_dispatch=V2_API) + FMHA_FWD_SPLITKV_API, |
19 | | -+ 3: FMHA_FWD_API.format(F_inner_dispatch=COMBINED_API), |
20 | | - 4: FMHA_BATCH_PREFILL_API, |
21 | | - 5: FMHA_FWD_API.format(F_inner_dispatch=COMBINED_API) |
22 | | - + FMHA_FWD_SPLITKV_API |
| 10 | + }, |
| 11 | + "module_mha_varlen_fwd": { |
23 | 12 |
|
24 | 13 | --- csrc/py_itfs_cu/asm_pa.cu |
25 | 14 | +++ csrc/py_itfs_cu/asm_pa.cu |
|
89 | 78 | from . import mla |
90 | 79 | +from . import paged_attn |
91 | 80 |
|
| 81 | + |
92 | 82 | --- aiter/ops/gemm_op_a8w8.py |
93 | 83 | +++ aiter/ops/gemm_op_a8w8.py |
94 | 84 | @@ -425,9 +425,11 @@ |
|
0 commit comments