diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 6df6ed4da..3423a707d 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -833,7 +833,7 @@ "'-DCK_TILE_FMHA_FWD_FAST_EXP2=1'", "f'-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get(\"CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT\", 2)}'", "f'-DCK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT={os.environ.get(\"CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT\", 0)}'", - "f'-DCK_TILE_ATTENTION_USE_SOFTSIGN_ASM={os.environ.get(\"CK_TILE_ATTENTION_USE_SOFTSIGN_ASM\", 1)}'" + "'-DCK_TILE_ATTENTION_USE_SOFTSIGN_ASM=0'" ], "extra_ldflags": "None", "extra_include": [ @@ -970,7 +970,7 @@ "blob_gen_cmd": [ "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd --receipt 600 --output_dir {{}}'", "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --receipt 600 --output_dir {{}}'", - "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill --receipt 600 --output_dir {{}}'", + "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill --receipt 600 --optdim 128 --filter \"fmha_batch_prefill_d128_bf16_*_ndropout_*_nbias_*_nmask_*_nlse_*\" --output_dir {{}}'", "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 5 --output_dir {{}}'", "f'{get_asm_dir()}/fmha_v3_fwd/codegen.py --output_dir {{}}'" ] diff --git a/setup.py b/setup.py index 167417a4d..092c0c231 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ # !!!!!!!!!!!!!!!! never import aiter # from aiter.jit import core this_dir = os.path.dirname(os.path.abspath(__file__)) +os.environ["AITER_META_DIR"] = this_dir sys.path.insert(0, f"{this_dir}/aiter/") from concurrent.futures import ThreadPoolExecutor @@ -73,13 +74,6 @@ def is_develop_mode(): ck_dir ), 'CK is needed by aiter, please make sure clone by "git clone --recursive https://github.com/ROCm/aiter.git" or "git submodule sync ; git submodule update --init --recursive"' - if os.path.exists("aiter_meta") and os.path.isdir("aiter_meta"): - shutil.rmtree("aiter_meta") - shutil.copytree("3rdparty", "aiter_meta/3rdparty") - shutil.copytree("hsa", "aiter_meta/hsa") - shutil.copytree("gradlib", "aiter_meta/gradlib") - shutil.copytree("csrc", "aiter_meta/csrc") - def get_exclude_ops(): if PREBUILD_KERNELS == 1: return [ @@ -174,6 +168,8 @@ def get_exclude_ops(): "module_mla_metadata", "module_mla_reduce", ] + elif PREBUILD_KERNELS == 4: + return [] else: return [] @@ -225,7 +221,13 @@ def build_one_module(one_opt_args): raise NotImplementedError("Only ROCM is supported") -# aiter_meta prepared above +if os.path.exists("aiter_meta") and os.path.isdir("aiter_meta"): + shutil.rmtree("aiter_meta") +## link "3rdparty", "hsa", "csrc" into "aiter_meta" +shutil.copytree("3rdparty", "aiter_meta/3rdparty") +shutil.copytree("hsa", "aiter_meta/hsa") +shutil.copytree("gradlib", "aiter_meta/gradlib") +shutil.copytree("csrc", "aiter_meta/csrc") class NinjaBuildExtension(BuildExtension): @@ -293,4 +295,4 @@ def has_ext_modules(self): ) if os.path.exists("aiter_meta") and os.path.isdir("aiter_meta"): - shutil.rmtree("aiter_meta") + shutil.rmtree("aiter_meta") \ No newline at end of file