diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index f4605c08f..548604b2b 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -976,7 +976,7 @@ ] }, "libmha_bwd": { - "srcs": [], + "srcs": ["f'{AITER_CSRC_DIR}/cpp_itfs/fmha_v3_bwd.cpp'"], "flags_extra_cc": [], "flags_extra_hip": [ "f'-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get(\"CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT\", 0)}'" @@ -990,10 +990,7 @@ "is_standalone": "False", "torch_exclude": "True", "blob_gen_cmd": [ - "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd --receipt 600 --output_dir {{}}'", - "f'{AITER_CSRC_DIR}/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py --filter \"*@*_ndeterministic@*_nbias*_dropout*_ndeterministic*\" --output_dir {{}}'", - "f'{get_asm_dir()}/fmha_v3_bwd/codegen.py --output_dir {{}}'", - "f'{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 3 --output_dir {{}}'" + "f'{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd --receipt 600 --output_dir {{}}'" ] }, "module_rocsolgemm": { diff --git a/csrc/cpp_itfs/fmha_v3_bwd.cpp b/csrc/cpp_itfs/fmha_v3_bwd.cpp new file mode 100644 index 000000000..14bd48fe1 --- /dev/null +++ b/csrc/cpp_itfs/fmha_v3_bwd.cpp @@ -0,0 +1,538 @@ +#include "asm_fmha_v3_bwd_configs.hpp" +#include "aiter_hip_common.h" +#include "fmha_bwd.hpp" +#include "mha_bwd.h" +#include + +namespace aiter { +std::tuple get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id) +{ + if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950") + return std::make_tuple(hdim_q, hdim_v); + assert(hdim_q == hdim_v); + if(hdim_q <= 64) + { + return std::make_tuple(64, 64); + } + else if(hdim_q <= 128) + { + return std::make_tuple(128, 128); + } + else if(hdim_q <= 192) + { + return std::make_tuple(192, 192); + } + + assert(false); + return std::make_tuple(hdim_q, hdim_v); +} + +std::tuple get_heuristic_kernel(std::string data_type, + std::string arch_id, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + int mask_type, + bool atomic32, + int bf16_cvt, + bool mode, + CFG* pre_cfgs, + CFG* cfgs, + CFG* post_cfgs) +{ + auto [padded_hdim_q, padded_hdim_v] = get_padded_hdim(hdim_q, hdim_v, arch_id); + int pddv = (padded_hdim_q != hdim_q) || (padded_hdim_v != hdim_v); + int pssk; + int ts_kv = 0; + + // std::cout << "padded_hdim_q: " << padded_hdim_q << ", padded_hdim_v: " << padded_hdim_v << std::endl; + // std::cout << "pddv: " << pddv << std::endl; + + std::string preProcessingKernelName = ""; + std::string dQdKdVKernelName = ""; + std::string postProcessingKernelName = ""; + + for(const auto& el : *pre_cfgs) + { + if(el.first.find(arch_id) != 0) + continue; + const auto& cfg = el.second; + + if((cfg.dtype == data_type) && (cfg.hdim_q == padded_hdim_q) && (cfg.mode == mode)) + { + preProcessingKernelName = el.first; + break; + } + } + + for(const auto& el : *cfgs) + { + if(el.first.find(arch_id) != 0) { + // std::cout << "not supported arch" << std::endl; + continue; + } + const auto& cfg = el.second; + + if((cfg.dtype == data_type) && (cfg.hdim_q == padded_hdim_q) && (cfg.hdim_v == padded_hdim_v) && + (cfg.mask == mask_type) && (cfg.atomic32 == atomic32) && + ((arch_id == "gfx950") || ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt))) && (cfg.mode == mode)) + { + // std::cout << "kernel name: " << el.first << std::endl; + if(ts_kv == 0) + { + // std::cout << "first chosen ts_kv: " << cfg.ts << std::endl; + ts_kv= cfg.ts; + pssk = (seqlen_q != seqlen_k) || (seqlen_q % ts_kv != 0); + // std::cout << "pssk: " << pssk << std::endl; + } + if((cfg.pssk == pssk) && (cfg.pddv == pddv)) + { + // std::cout << "all matched case: " << el.first << std::endl; + dQdKdVKernelName = el.first; + break; + } + else if((cfg.pssk >= pssk) && (cfg.pddv >= pddv)) + { + // std::cout << "some matched case: " << el.first << std::endl; + dQdKdVKernelName = el.first; + } + } + } + + if (!post_cfgs) { + return std::make_tuple(preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName); + } + + for(const auto& el : *post_cfgs) + { + if(el.first.find(arch_id) != 0) + continue; + const auto& cfg = el.second; + + if((cfg.dtype == data_type) && (cfg.hdim_q == padded_hdim_q) && (cfg.mode == mode) && + ((arch_id == "gfx950") || ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt)))) + { + postProcessingKernelName = el.first; + break; + } + } + return std::make_tuple(preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName); +} + +float mha_bwd(mha_bwd_args a) +{ + if (a.use_asm_v3 == 1) { + return fmha_v3_bwd(a); + } + // else { + // fmha_bwd_traits traits{ + // a.hdim_q, + // a.hdim_v, + // a.data_type, + // a.is_group_mode, + // static_cast(a.ck_mask_type), + // static_cast(a.bias_type), + // a.has_dbias, + // a.has_dropout, + // a.is_store_randval, + // a.is_deterministic + // }; + + // fmha_bwd_args ck_args{ + // /* q_ptr */ a.q_ptr, + // /* k_ptr */ a.k_ptr, + // /* v_ptr */ a.v_ptr, + // /* bias_ptr */ a.bias_ptr, + // /* o_ptr */ a.o_ptr, + // /* lse_ptr */ a.lse_ptr, + // /* do_ptr */ a.do_ptr, + // /* d_ptr */ a.d_ptr, + // /* rand_val_ptr */ a.rand_val_ptr, + // /* dq_ptr */ a.dq_ptr, + // /* dk_ptr */ a.dk_ptr, + // /* dv_ptr */ a.dv_ptr, + // /* dbias_ptr */ a.dbias_ptr, + // /* dq_acc_ptr */ a.dq_acc_ptr, + + // /* seqstart_q_ptr */ a.seqstart_q_ptr, + // /* seqstart_k_ptr */ a.seqstart_k_ptr, + // /* seqlen_q_ptr */ a.seqlen_q_ptr, + // /* seqlen_k_ptr */ a.seqlen_k_ptr, + // /* cu_seqlen_q_ptr */ a.cu_seqlen_q_ptr, + // /* cu_seqlen_k_ptr */ a.cu_seqlen_k_ptr, + + // /* seqlen_q */ a.seqlen_q, + // /* seqlen_k */ a.seqlen_k, + // /* batch */ a.batch, + // /* max_seqlen_q */ a.max_seqlen_q, + // /* max_seqlen_k */ a.max_seqlen_k, + // /* hdim_q */ a.hdim_q, + // /* hdim_v */ a.hdim_v, + // /* nhead_q */ a.nhead_q, + // /* nhead_k */ a.nhead_k, + // /* scale */ a.scale, + + // /* stride_q */ a.stride_q, + // /* stride_k */ a.stride_k, + // /* stride_v */ a.stride_v, + // /* stride_bias */ a.stride_bias, + // /* stride_o */ a.stride_o, + // /* stride_randval */ a.stride_randval, + // /* stride_do */ a.stride_do, + // /* stride_dq_acc */ a.stride_dq_acc, + // /* stride_dq */ a.stride_dq, + // /* stride_dk */ a.stride_dk, + // /* stride_dv */ a.stride_dv, + // /* stride_dbias */ a.stride_dbias, + + // /* nhead_stride_q */ a.nhead_stride_q, + // /* nhead_stride_k */ a.nhead_stride_k, + // /* nhead_stride_v */ a.nhead_stride_v, + // /* nhead_stride_bias */ a.nhead_stride_bias, + // /* nhead_stride_o */ a.nhead_stride_o, + // /* nhead_stride_randval*/ a.nhead_stride_randval, + // /* nhead_stride_do */ a.nhead_stride_do, + // /* nhead_stride_lsed */ a.nhead_stride_lsed, + // /* nhead_stride_dq_acc*/ a.nhead_stride_dq_acc, + // /* nhead_stride_dq */ a.nhead_stride_dq, + // /* nhead_stride_dk */ a.nhead_stride_dk, + // /* nhead_stride_dv */ a.nhead_stride_dv, + // /* nhead_stride_dbias */ a.nhead_stride_dbias, + + // /* batch_stride_q */ a.batch_stride_q, + // /* batch_stride_k */ a.batch_stride_k, + // /* batch_stride_v */ a.batch_stride_v, + // /* batch_stride_bias */ a.batch_stride_bias, + // /* batch_stride_o */ a.batch_stride_o, + // /* batch_stride_randval*/ a.batch_stride_randval, + // /* batch_stride_do */ a.batch_stride_do, + // /* batch_stride_lsed */ a.batch_stride_lsed, + // /* batch_stride_dq_acc*/ a.batch_stride_dq_acc, + // /* batch_stride_dq */ a.batch_stride_dq, + // /* batch_stride_dk */ a.batch_stride_dk, + // /* batch_stride_dv */ a.batch_stride_dv, + // /* batch_stride_dbias */ a.batch_stride_dbias, + + // /* split_stride_dq_acc*/ a.split_stride_dq_acc, + // /* window_size_left */ a.window_size_left, + // /* window_size_right */ a.window_size_right, + // /* mask_type */ a.ck_mask_type, + // /* p_drop */ a.p_drop, + // /* p_undrop */ a.p_undrop, + // /* drop_seed_offset */ a.drop_seed_offset, + // }; + + // ck_tile::stream_config stream_config{ + // a.stream_id_, a.time_kernel_, a.log_level_, a.cold_niters_, a.nrepeat_, a.is_gpu_timer_, a.flush_cache_, a.rotating_count_ + // }; + + // if (a.use_asm_v3 == 0) { + // float asm_ret = fmha_v3_bwd(a); + // if(asm_ret == -1) + // { + // return fmha_bwd(traits, ck_args, stream_config); + // } + // } else { + // return fmha_bwd(traits, ck_args, stream_config); + // } + // } +} + +float fmha_v3_bwd(mha_bwd_args a) +{ + int ts_odo; + int ts_kv; + int ts_dq; + int arg_size; + + std::string arch_id = get_gpu_arch(); + auto pre_cfgs = &cfg_fmha_bwd_odo; + auto dqdkdv_cfgs = &cfg_fmha_bwd_dqdkdv; + auto post_cfgs = [&]() { + if (arch_id == "gfx950") { + if (a.v3_atomic_fp32) { + return &cfg_fmha_bwd_dq_convert; + } else { + return &cfg_fmha_bwd_dq_shuffle; + } + } else { + if (a.v3_atomic_fp32) { + return &cfg_fmha_bwd_dq_convert; + } else { + return static_cast(nullptr); + } + } + }(); + + AiterAsmKernel* impl_ptr_pre = nullptr; + AiterAsmKernel* impl_ptr_dqdkdv = nullptr; + AiterAsmKernel* impl_ptr_post = nullptr; + + bool need_post_processing = (arch_id == "gfx950") || (a.v3_atomic_fp32 == 1); + + auto [pre_kernel, dqdkdv_kernel, post_kernel] = + get_heuristic_kernel(a.data_type, + arch_id, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.hdim_v, + a.mask_type, + a.v3_atomic_fp32, + a.v3_bf16_cvt, + a.is_group_mode, + pre_cfgs, + dqdkdv_cfgs, + post_cfgs); + + auto it_pre = pre_cfgs->find(pre_kernel); + if(it_pre != pre_cfgs->end()) + { + const auto& cfg = it_pre->second; + const char* name = cfg.knl_name.c_str(); + const char* co_name = cfg.co_name.c_str(); + ts_odo = cfg.ts; + + static AiterAsmKernel impl_kenrel(name, co_name); + impl_ptr_pre = &impl_kenrel; + } + else + { + return -1; + } + + auto it_dqdkdv = dqdkdv_cfgs->find(dqdkdv_kernel); + // std::cout << "it_dqdkdv name: " << dqdkdv_kernel << std::endl; + if(it_dqdkdv != dqdkdv_cfgs->end()) + { + const auto& cfg = it_dqdkdv->second; + const char* name = cfg.knl_name.c_str(); + const char* co_name = cfg.co_name.c_str(); + ts_kv = cfg.ts; + + static AiterAsmKernel impl_kenrel(name, co_name); + impl_ptr_dqdkdv = &impl_kenrel; + } + else + { + // std::cout << "Cannot find dqdkdv kernel: " << dqdkdv_kernel << std::endl; + return -1; + } + + if (!post_cfgs) { + assert((!need_post_processing) && (post_kernel == "")); + } else { + auto it_post = post_cfgs->find(post_kernel); + if(it_post != post_cfgs->end()) + { + const auto& cfg = it_post->second; + const char* name = cfg.knl_name.c_str(); + const char* co_name = cfg.co_name.c_str(); + ts_dq = cfg.ts; + + static AiterAsmKernel impl_kenrel(name, co_name); + impl_ptr_post = &impl_kenrel; + } + else + { + return -1; + } + } + + if (a.v3_api_check) return 1; + + fmha_bwd_odo_args odo_args; + arg_size = sizeof(odo_args); + odo_args.ptr_o = a.o_ptr; + odo_args.ptr_do = a.do_ptr; + odo_args.ptr_d = a.d_ptr; + odo_args.Hs_odo = a.nhead_stride_o * 2; + odo_args.BAs_odo = a.batch_stride_o * 2; + odo_args.Seqs_odo = a.stride_o * 2; + odo_args.Hs_d = a.nhead_stride_lsed * 4; + odo_args.BAs_d = a.batch_stride_lsed * 4; + odo_args.Seqs_d = 1 * 4; + odo_args.seqlen_q = a.seqlen_q; + odo_args.head_dim = a.hdim_q; + odo_args.ptr_qseq_padded = a.seqstart_q_ptr; + odo_args.ptr_qseq = (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) + ? a.cu_seqlen_q_ptr + : a.seqstart_q_ptr; + + auto pre_kernel_launch = + [&]() { + int bdx = 256; + int gdx = (a.max_seqlen_q + ts_odo - 1) / ts_odo; + int gdy = a.nhead_q; + int gdz = a.batch; + + impl_ptr_pre->launch_kernel({&odo_args, + &arg_size, + gdx, + gdy, + gdz, + bdx, + 1, + 1, + a.stream_id_}); + }; + + fmha_bwd_dqdkdv_args dqdkdv_args; + dqdkdv_args.ptr_dq = need_post_processing ? a.dq_acc_ptr : a.dq_ptr; + dqdkdv_args.ptr_dk = a.dk_ptr; + dqdkdv_args.ptr_dv = a.dv_ptr; + dqdkdv_args.ptr_q = a.q_ptr; + dqdkdv_args.ptr_k = a.k_ptr; + dqdkdv_args.ptr_v = a.v_ptr; + dqdkdv_args.ptr_do = a.do_ptr; + dqdkdv_args.ptr_lse = a.lse_ptr; + dqdkdv_args.ptr_d = a.d_ptr; + dqdkdv_args.scalar = a.scale; + dqdkdv_args.log2e = ck_tile::log2e_v; + dqdkdv_args.ratio = a.nhead_q / a.nhead_k; + dqdkdv_args.seqlen_q = a.seqlen_q; + dqdkdv_args.seqlen_k = a.seqlen_k; + dqdkdv_args.head_dim_q = a.hdim_q; + dqdkdv_args.head_dim_v = a.hdim_v; + dqdkdv_args.nhead_q = a.nhead_q; + dqdkdv_args.Ts = ts_kv * a.stride_k * 2; + dqdkdv_args.Hs_q = a.nhead_stride_q * 2; + dqdkdv_args.BAs_q = a.batch_stride_q * 2; + dqdkdv_args.Seqs_q = a.stride_q * 2; + dqdkdv_args.Hs_k = a.nhead_stride_k * 2; + dqdkdv_args.BAs_k = a.batch_stride_k * 2; + dqdkdv_args.Seqs_k = a.stride_k * 2; + dqdkdv_args.Hs_v = a.nhead_stride_v * 2; + dqdkdv_args.BAs_v = a.batch_stride_v * 2; + dqdkdv_args.Seqs_v = a.stride_v * 2; + dqdkdv_args.Hs_do = a.nhead_stride_do * 2; + dqdkdv_args.BAs_do = a.batch_stride_do * 2; + dqdkdv_args.Seqs_do = a.stride_do * 2; + dqdkdv_args.Hs_dk = a.nhead_stride_dk * 2; + dqdkdv_args.BAs_dk = a.batch_stride_dk * 2; + dqdkdv_args.Seqs_dk = a.stride_dk * 2; + dqdkdv_args.Hs_dv = a.nhead_stride_dv * 2; + dqdkdv_args.BAs_dv = a.batch_stride_dv * 2; + dqdkdv_args.Seqs_dv = a.stride_dv * 2; + dqdkdv_args.Hs_lsed = a.nhead_stride_lsed * 4; + + if (a.cu_seqlen_k_ptr && a.seqstart_k_ptr) { + dqdkdv_args.ptr_kseq_padded = a.seqstart_k_ptr; + dqdkdv_args.ptr_kseq = a.cu_seqlen_k_ptr; + } else { + dqdkdv_args.ptr_kseq = a.seqstart_k_ptr; + dqdkdv_args.ptr_kseq_padded = a.seqstart_k_ptr; + } + + if (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) { + dqdkdv_args.ptr_qseq_padded = a.seqstart_q_ptr; + dqdkdv_args.ptr_qseq = a.cu_seqlen_q_ptr; + } else { + dqdkdv_args.ptr_qseq = a.seqstart_q_ptr; + dqdkdv_args.ptr_qseq_padded = a.seqstart_q_ptr; + } + dqdkdv_args.max_seqlen_dq = a.v3_atomic_fp32 + ? a.max_seqlen_q + : (a.max_seqlen_q + 15) / 16 * 16; + // convert l/r to x/y HERE + if (a.window_size_left == -1 && a.window_size_right == 0) { + dqdkdv_args.mask_y = 0; + dqdkdv_args.mask_x = 0; + } else { + auto generic_mask = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + a.window_size_left, a.window_size_right, a.seqlen_q, a.seqlen_k, + (a.mask_type == static_cast(mask_enum::mask_top_left) || + a.mask_type == static_cast(mask_enum::window_generic))); + dqdkdv_args.mask_y = generic_mask.at(ck_tile::number<0>{}); + dqdkdv_args.mask_x = generic_mask.at(ck_tile::number<1>{}); + } + arg_size = sizeof(dqdkdv_args); + auto dqdkdv_kernel_launch = + [&]() { + int bdx = 256; + int gdx = (a.max_seqlen_k + ts_kv - 1) / ts_kv; + int gdy = a.nhead_q; + int gdz = a.batch; + + if(a.mask_type == 1 || a.mask_type == 2) + { // sliding window + gdx = (gdx + 1) / 2; + } + + impl_ptr_dqdkdv->launch_kernel({&dqdkdv_args, + &arg_size, + gdx, + gdy, + gdz, + bdx, + 1, + 1, + a.stream_id_}); + }; + ck_tile::stream_config s{a.stream_id_, a.time_kernel_, a.log_level_, a.cold_niters_, a.nrepeat_, a.is_gpu_timer_, a.flush_cache_, a.rotating_count_}; + + if (!need_post_processing) { + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_) { pre_kernel_launch(); }, + [=](const ck_tile::stream_config& s_) { dqdkdv_kernel_launch(); } + ); + } + + int dq_acc_element_size = a.v3_atomic_fp32? 4: 2; + std::cout << "dq_acc_element_size: " << dq_acc_element_size << std::endl; + fmha_bwd_post_kernel_args post_args; + arg_size = sizeof(post_args); + post_args.ptr_dq_acc = a.dq_acc_ptr; + post_args.ptr_dq = a.dq_ptr; + post_args.Hs_dq_acc = a.nhead_stride_dq_acc * dq_acc_element_size; + post_args.BAs_dq_acc = a.batch_stride_dq_acc * dq_acc_element_size; + post_args.Seqs_dq_acc = a.stride_dq_acc * dq_acc_element_size; + post_args.Hs_dq = a.nhead_stride_dq * 2; + post_args.BAs_dq = a.batch_stride_dq * 2; + post_args.Seqs_dq = a.stride_dq * 2; + post_args.seqlen_q = a.seqlen_q; + post_args.head_dim = a.hdim_q; + post_args.ptr_qseq_padded = a.seqstart_q_ptr; + post_args.ptr_qseq = (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) + ? a.cu_seqlen_q_ptr + : a.seqstart_q_ptr; + std::cout << "Hs_dq_acc: " << post_args.Hs_dq_acc << std::endl + << "BAs_dq_acc: " << post_args.BAs_dq_acc << std::endl + << "Seqs_dq_acc: " << post_args.Seqs_dq_acc << std::endl + << "Hs_dq: " << post_args.Hs_dq << std::endl + << "BAs_dq: " << post_args.BAs_dq << std::endl + << "Seqs_dq: " << post_args.Seqs_dq << std::endl; + + auto post_kernel_launch = + [&]() { + int bdx = 256; + int gdx = (a.max_seqlen_q + ts_dq - 1) / ts_dq; + int gdy = a.nhead_q; + int gdz = a.batch; + + impl_ptr_post->launch_kernel({&post_args, + &arg_size, + gdx, + gdy, + gdz, + bdx, + 1, + 1, + a.stream_id_}); + }; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_) { pre_kernel_launch(); }, + [=](const ck_tile::stream_config& s_) { dqdkdv_kernel_launch(); }, + [=](const ck_tile::stream_config& s_) { post_kernel_launch(); } + ); + // pre_kernel_launch(); + // std::cout << "pre_kernel_launch done" << std::endl; + // dqdkdv_kernel_launch(); + // std::cout << "dqdkdv_kernel_launch done" << std::endl; + // post_kernel_launch(); + // std::cout << "post_kernel_launch done" << std::endl; + // return 1; +} + +} // namespace aiter diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index 79cca4f2e..700d01883 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -5,339 +5,263 @@ // Include these 2 headers instead of torch/extension.h since we don't need all of the torch // headers. #include "aiter_hip_common.h" -#include "fmha_bwd.hpp" -#include "mask.hpp" +#include namespace aiter { -struct mha_bwd_traits : public fmha_bwd_traits -{ - mha_bwd_traits(int head_size_q, - int head_size_v, - std::string dtype, - bool is_group_mode, - mask_enum mask_type, - bias_enum bias_type, - bool has_dbias, - bool has_dropout, - bool is_store_randval, - bool deterministic, - bool use_ext_asm, - bool is_v3_atomic_fp32, - int how_v3_bf16_cvt) - : fmha_bwd_traits{head_size_q, - head_size_v, - dtype, - is_group_mode, - mask_type, - bias_type, - has_dbias, - has_dropout, - is_store_randval, - deterministic}, - use_ext_asm(use_ext_asm), - is_v3_atomic_fp32(is_v3_atomic_fp32), - how_v3_bf16_cvt(how_v3_bf16_cvt) - { - } - bool use_ext_asm; - bool is_v3_atomic_fp32; - int how_v3_bf16_cvt; -}; -using mha_bwd_args = fmha_bwd_args; +struct mha_bwd_args { + // stream config + /* + * construct this structure with behavior as: + * + * // create stream config with default stream(NULL), and not timing the kernel + * stream_config s = stream_config{}; + * + * // create stream config with _some_stream_id_, and not timing the kernel + * stream_config s = stream_config{_some_stream_id_}; + * + * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default + * stream_config s = stream_config{_some_stream_id_, true}; + * + * // create stream config with _some_stream_id_, and benchmark using cpu timer + * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false}; + * + * // create stream config with _some_stream_id_, and enable gpu timer for rotating buffer with + *rotating buffer count stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, true, + *true, 1}; + **/ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; + int log_level_ = 0; + int cold_niters_ = 3; + int nrepeat_ = 10; + bool is_gpu_timer_ = true; // keep compatible + bool flush_cache_ = false; + int rotating_count_ = 1; -// FIXME: use aiter mha_args -__attribute__((visibility("default"))) float mha_bwd(mha_bwd_args args, - const ck_tile::stream_config& stream_config, - std::string q_dtype_str, - bool is_group_mode, - mask_enum mask_type, - bias_enum bias_type, - bool has_dbias, - bool is_store_randval, - bool deterministic, - bool use_ext_asm, - bool is_v3_atomic_fp32, - int how_v3_bf16_cvt, - const void* seqlen_q_padded = nullptr, - const void* seqlen_k_padded = nullptr, - bool is_v3_api_check = false); + // aiter args + int mask_type; // 0: no mask 1: top_left_causal 2: bottom_right_causal 3: sliding_window + bool use_asm_v3; // 0: default(asm first) 1: force asm 2: force ck + bool v3_atomic_fp32; + int v3_bf16_cvt; + bool v3_api_check; -struct __attribute__((packed)) fmha_bwd_v3_args -{ - void* ptr_dq; - p2 _p0; - void* ptr_dk; - p2 _p1; - void* ptr_dv; - p2 _p2; - const void* ptr_q; - p2 _p3; - const void* ptr_k; - p2 _p4; - const void* ptr_v; - p2 _p5; - const void* ptr_do; - p2 _p6; - const void* ptr_lse; - p2 _p7; - const void* ptr_d; - p2 _p8; - float scalar; - p3 _p9; - float log2e; - p3 _p10; - unsigned int seq_len; - p3 _p11; - unsigned int Ts; - p3 _p12; - unsigned int Hs; - p3 _p13; - unsigned int BAs; - p3 _p14; - unsigned int Seqs; - p3 _p15; - unsigned int ratio; - p3 _p16; - unsigned int Hs_kv; - p3 _p17; - unsigned int BAs_kv; - p3 _p18; - unsigned int Seqs_kv; - p3 _p19; - unsigned int Seqs_dkv; - p3 _p20; + // From ck fmha_bwd_traits + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + int ck_mask_type; + int bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + bool is_store_randval; + bool is_deterministic; + + // From ck fmha_bwd_args + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + void* dq_acc_ptr; + // Usage notes for sequence length pointer parameters: + // + // [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles + // MQA/GQA..."] + // + // With padding: + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence + // lengths. [array size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each + // sequence. [array size: batch] + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually + // exclusive. Use one set, not both. + // + // Batch mode: + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqstart_* and seqlen_* pointers must be nullptr. + // + // Without padding: + // (Note: Physical length equals logical length) + // + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array + // size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr. + // + // Batch mode: + // - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr. + // + const void* seqstart_q_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqstart_k_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + int seqlen_q; + int seqlen_k; + int batch; + int max_seqlen_q; + int max_seqlen_k; + int nhead_q; + int nhead_k; + float scale; + int stride_q; + int stride_k; + int stride_v; + int stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + int stride_o; + int stride_randval; + int stride_do; + int stride_dq_acc; + int stride_dq; + int stride_dk; + int stride_dv; + int stride_dbias; + int nhead_stride_q; + int nhead_stride_k; + int nhead_stride_v; + int nhead_stride_bias; + int nhead_stride_o; + int nhead_stride_randval; + int nhead_stride_do; + int nhead_stride_lsed; + int nhead_stride_dq_acc; + int nhead_stride_dq; + int nhead_stride_dk; + int nhead_stride_dv; + int nhead_stride_dbias; + int batch_stride_q; + int batch_stride_k; + int batch_stride_v; + int batch_stride_bias; + int batch_stride_o; + int batch_stride_randval; + int batch_stride_do; + int batch_stride_lsed; + int batch_stride_dq_acc; + int batch_stride_dq; + int batch_stride_dk; + int batch_stride_dv; + int batch_stride_dbias; + int split_stride_dq_acc; + int window_size_left; + int window_size_right; + float p_drop; + float p_undrop; + std::variant, std::pair> + drop_seed_offset; }; -struct __attribute__((packed)) fmha_bwd_v3_gen_args +struct __attribute__((packed)) fmha_bwd_dqdkdv_args { - void* ptr_dq; + void *ptr_dq; // 0x00: dq or dq_acc p2 _p0; - void* ptr_dk; + void *ptr_dk; // 0x10 p2 _p1; - void* ptr_dv; + void *ptr_dv; // 0x20 p2 _p2; - const void* ptr_q; + const void *ptr_q; // 0x30 p2 _p3; - const void* ptr_k; + const void *ptr_k; // 0x40 p2 _p4; - const void* ptr_v; + const void *ptr_v; // 0x50 p2 _p5; - const void* ptr_do; + const void *ptr_do; // 0x60 p2 _p6; - const void* ptr_lse; + const void *ptr_lse; // 0x70 p2 _p7; - const void* ptr_d; + const void *ptr_d; // 0x80 p2 _p8; - float scalar; + float scalar; // 0x90 p3 _p9; - float log2e; + float log2e; // 0xa0 p3 _p10; - unsigned int seq_len; + unsigned int seqlen_q; // 0xb0: s_seq_len_q p3 _p11; - unsigned int Ts; + unsigned int Ts; // 0xc0: s_Seqs_k*sub_K p3 _p12; - unsigned int Hs; + unsigned int Hs_q; // 0xd0: s_Hs_q p3 _p13; - unsigned int BAs; + unsigned int BAs_q; // 0xe0: s_BAs_q p3 _p14; - unsigned int Seqs; + unsigned int Seqs_q; // 0xf0: s_Seqs_q p3 _p15; - unsigned int ratio; + unsigned int ratio; // 0x100 p3 _p16; - unsigned int Hs_kv; + unsigned int Hs_k; // 0x110: s_Hs_k p3 _p17; - unsigned int BAs_kv; + unsigned int BAs_k; // 0x120: s_BAs_k p3 _p18; - unsigned int Seqs_kv; + unsigned int Seqs_k; // 0x130: s_Seqs_k p3 _p19; - unsigned int Seqs_dkv; + unsigned int Seqs_dk; // 0x140: s_Seqs_dk p3 _p20; - unsigned int head_dim; + unsigned int seqlen_k; // 0x150: batch mode p3 _p21; -}; - -struct __attribute__((packed)) fmha_bwd_v3_genl_args -{ - void* ptr_dq; - void* ptr_dk; - void* ptr_dv; - const void* ptr_q; - const void* ptr_k; - const void* ptr_v; - const void* ptr_do; - const void* ptr_lse; - const void* ptr_d; - float scalar; - p1 _p0; - float log2e; - p1 _p1; - unsigned int ratio; - p1 _p2; - unsigned int seqlen_q; - p1 _p3; - unsigned int seqlen_k; - p1 _p4; - unsigned int head_dim; - p1 _p5; - unsigned int nhead_q; - p1 _p6; - unsigned int Hs_q; - p1 _p7; - unsigned int BAs_q; - p1 _p8; - unsigned int Seqs_q; - p1 _p9; - unsigned int Hs_k; - p1 _p10; - unsigned int BAs_k; - p1 _p11; - unsigned int Seqs_k; - p1 _p12; - unsigned int Hs_v; - p1 _p13; - unsigned int BAs_v; - p1 _p14; - unsigned int Seqs_v; - p1 _p15; - unsigned int Hs_do; - p1 _p16; - unsigned int BAs_do; - p1 _p17; - unsigned int Seqs_do; - p1 _p18; - unsigned int Hs_dk; - p1 _p19; - unsigned int BAs_dk; - p1 _p20; - unsigned int Seqs_dk; - p1 _p21; - unsigned int Hs_dv; - p1 _p22; - unsigned int BAs_dv; - p1 _p23; - unsigned int Seqs_dv; - p1 _p24; -}; - -struct __attribute__((packed)) fmha_bwd_v3_group_args -{ - void* ptr_dq; - void* ptr_dk; - void* ptr_dv; - const void* ptr_q; - const void* ptr_k; - const void* ptr_v; - const void* ptr_do; - const void* ptr_lse; - const void* ptr_d; - const void* ptr_qseq; - const void* ptr_kseq; - const void* ptr_qseq_padded; - const void* ptr_kseq_padded; - float scalar; - p1 _p0; - float log2e; - p1 _p1; - unsigned int ratio; - p1 _p2; - unsigned int Hs_lsed; - p1 _p3; - unsigned int seqlen_k; // total length of k sequences - p1 _p4; - unsigned int Hs_q; - p1 _p5; - unsigned int Seqs_q; - p1 _p6; - unsigned int Hs_k; - p1 _p7; - unsigned int Seqs_k; - p1 _p8; - unsigned int Hs_v; - p1 _p9; - unsigned int Seqs_v; - p1 _p10; - unsigned int Hs_do; - p1 _p11; - unsigned int Seqs_do; - p1 _p12; - unsigned int Hs_dk; - p1 _p13; - unsigned int Seqs_dk; - p1 _p14; - unsigned int Hs_dv; - p1 _p15; - unsigned int Seqs_dv; - p1 _p16; - unsigned int head_dim; - p1 _p17; -}; - -struct __attribute__((packed)) fmha_bwd_v3_swa_genl_args -{ - void* ptr_dq; - void* ptr_dk; - void* ptr_dv; - const void* ptr_q; - const void* ptr_k; - const void* ptr_v; - const void* ptr_do; - const void* ptr_lse; - const void* ptr_d; - float scalar; - p1 _p0; - float log2e; - p1 _p1; - unsigned int ratio; - p1 _p2; - unsigned int seqlen_q; - p1 _p3; - unsigned int seqlen_k; - p1 _p4; - unsigned int head_dim; - p1 _p5; - unsigned int nhead_q; - p1 _p6; - unsigned int Hs_q; - p1 _p7; - unsigned int BAs_q; - p1 _p8; - unsigned int Seqs_q; - p1 _p9; - unsigned int Hs_k; - p1 _p10; - unsigned int BAs_k; - p1 _p11; - unsigned int Seqs_k; - p1 _p12; - unsigned int Hs_v; - p1 _p13; - unsigned int BAs_v; - p1 _p14; - unsigned int Seqs_v; - p1 _p15; - unsigned int Hs_do; - p1 _p16; - unsigned int BAs_do; - p1 _p17; - unsigned int Seqs_do; - p1 _p18; - unsigned int Hs_dk; - p1 _p19; - unsigned int BAs_dk; - p1 _p20; - unsigned int Seqs_dk; - p1 _p21; - unsigned int Hs_dv; - p1 _p22; - unsigned int BAs_dv; - p1 _p23; - unsigned int Seqs_dv; - p1 _p24; - int mask_x; - p1 _p25; - int mask_y; - p1 _p26; + unsigned int head_dim_q; // 0x160: batch&group mode for headdim padding + p3 _p22; + unsigned int head_dim_v; // 0x170: batch&group mode for headdim padding + p3 _p23; + unsigned int nhead_q; // 0x180: batch mode lsed([B,H,S]) addr = batch_idx * nhead_q * seqlen_q * 4 + head_idx * seqlen_q * 4 + p3 _p24; + unsigned int Hs_v; // 0x190: batch&group mode + p3 _p25; + unsigned int BAs_v; // 0x1a0: batch mode + p3 _p26; + unsigned int Seqs_v; // 0x1b0: batch&group mode + p3 _p27; + unsigned int Hs_do; // 0x1c0: batch&group mode + p3 _p28; + unsigned int BAs_do; // 0x1d0: group mode + p3 _p29; + unsigned int Seqs_do; // 0x1e0: batch&group mode + p3 _p30; + unsigned int Hs_dk; // 0x1f0: batch&group mode + p3 _p31; + unsigned int BAs_dk; // 0x200: group mode + p3 _p32; + unsigned int Hs_dv; // 0x210: batch&group mode + p3 _p33; + unsigned int BAs_dv; // 0x220: group mode + p3 _p34; + unsigned int Seqs_dv; // 0x230: batch&group mode + p3 _p35; + unsigned int Hs_lsed; // 0x240: group mode lsed([H,TotalValid_Q(90)]) addr = seqstart_q[batch_idx] * 4 + head_idx * nhead_stride_lsed(s_Hs_lsed) + p3 _p36; + const void *ptr_qseq; // 0x250: group mode seqstart_q [0, 20, 50, 90] + p2 _p37; + const void *ptr_kseq; // 0x260: group mode seqstart_k [0, 50, 110, 180] + p2 _p38; + const void *ptr_qseq_padded; // 0x270: group mode seqstart_q_padded [0, 30(20+10), 70(20+10+30+10), 120(20+10+30+10+40+10)] if 10 is padded after each seqlen[30(20+10), 40(30+10), 50(40+10)] + p2 _p39; + const void *ptr_kseq_padded; // 0x280: group mode seqstart_k_padded [0, 60(50+10), 130(50+10+60+10), 200(50+10+60+10+70+10)] if 10 is padded after each seqlen[60(50+10), 70(60+10), 80(70+10)] + p2 _p40; + unsigned int max_seqlen_dq; // 0x290: gorup mode max seqlen q for a16 dq_acc store, padding to 16x + p3 _p41; + int mask_x; // 0x2a0 + p3 _p42; + int mask_y; // 0x2b0 + p3 _p43; }; struct __attribute__((packed)) fmha_bwd_odo_args @@ -399,67 +323,8 @@ struct __attribute__((packed)) fmha_bwd_post_kernel_args p2 _p11; }; -struct fmha_bwd_v3_traits -{ - int b; - int h; - int sq; - int sk; - int d; - - int mask; - int ts_qo; - int ts_kv; - - int ts_pre_kernel = 128; - int ts_post_kernel = 64; -}; - -template -struct fmha_bwd_dq_dk_dv_v3_traits_ -{ - static constexpr ck_tile::index_t HDim_q = HDim_q_; - static constexpr ck_tile::index_t HDim_v = HDim_v_; - using DataType = ck_tile::remove_cvref_t; - static constexpr int mask_type = mask_type_; - static constexpr bool kIsAtomic32 = kIsAtomic32_; - static constexpr ck_tile::index_t BF16Cvt = BF16Cvt_; - static constexpr bool kIsSEQPad = kIsSEQPad_; - static constexpr bool kIsHDPad = kIsHDPad_; - static constexpr bool kIsGroupMode = kIsGroupMode_; -}; - -template -struct FmhaBwdV3Name; -template -struct FmhaBwdV3Buf; -template -struct FmhaBwdV3Ts; +__attribute__((visibility("default"))) float mha_bwd(mha_bwd_args args); -namespace gfx942 { -float fmha_bwd_v3(mha_bwd_traits t, - mha_bwd_args a, - const ck_tile::stream_config& s, - const void* seqlen_q_padded = nullptr, - const void* seqlen_k_padded = nullptr, - bool is_v3_api_check = false); -} +float fmha_v3_bwd(mha_bwd_args a); -namespace gfx950 { -float fmha_bwd_v3(mha_bwd_traits t, - mha_bwd_args a, - const ck_tile::stream_config& s, - const void* seqlen_q_padded = nullptr, - const void* seqlen_k_padded = nullptr, - bool is_v3_api_check = false); -} } // namespace aiter diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtna.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtna.co new file mode 100755 index 000000000..d379ef944 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtna_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtna_group.co new file mode 100755 index 000000000..5717e94ab Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtne.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtne.co new file mode 100755 index 000000000..706c37dcf Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtne_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtne_group.co new file mode 100755 index 000000000..fbeecddfe Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtz.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtz.co new file mode 100755 index 000000000..50a6f0cb9 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtz_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtz_group.co new file mode 100755 index 000000000..ce3e7d9bd Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_bf16_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_fp16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_fp16.co new file mode 100755 index 000000000..d7b48a946 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_fp16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_fp16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_fp16_group.co new file mode 100755 index 000000000..d983d7a93 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_dq_convert_fp16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_bf16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_bf16.co new file mode 100755 index 000000000..c5cd19b4c Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_bf16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_bf16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_bf16_group.co new file mode 100755 index 000000000..91767ff58 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_bf16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_fp16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_fp16.co new file mode 100755 index 000000000..d9fab582f Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_fp16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_fp16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_fp16_group.co new file mode 100755 index 000000000..4a8d3d0ca Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd128_odo_fp16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtna.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtna.co new file mode 100755 index 000000000..e3ecf68e9 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtna_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtna_group.co new file mode 100755 index 000000000..6e2fe66f3 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtne.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtne.co new file mode 100755 index 000000000..1f185d257 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtne_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtne_group.co new file mode 100755 index 000000000..a7de08a94 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtz.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtz.co new file mode 100755 index 000000000..0fd2fa14e Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtz_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtz_group.co new file mode 100755 index 000000000..b6c43d47f Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_bf16_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_fp16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_fp16.co new file mode 100755 index 000000000..34fb00289 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_fp16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_fp16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_fp16_group.co new file mode 100755 index 000000000..6d74c6319 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_dq_convert_fp16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_bf16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_bf16.co new file mode 100755 index 000000000..1f114c44e Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_bf16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_bf16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_bf16_group.co new file mode 100755 index 000000000..b418b35f1 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_bf16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_fp16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_fp16.co new file mode 100755 index 000000000..bfd112f73 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_fp16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_fp16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_fp16_group.co new file mode 100755 index 000000000..f83769509 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd192_odo_fp16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtna.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtna.co new file mode 100755 index 000000000..91da641b0 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtna_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtna_group.co new file mode 100755 index 000000000..381ae73c5 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtne.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtne.co new file mode 100755 index 000000000..87c7c5184 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtne_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtne_group.co new file mode 100755 index 000000000..dea1c120e Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtz.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtz.co new file mode 100755 index 000000000..5fc1e6583 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtz_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtz_group.co new file mode 100755 index 000000000..f0c689bb1 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_bf16_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_fp16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_fp16.co new file mode 100755 index 000000000..8aa285f9c Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_fp16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_fp16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_fp16_group.co new file mode 100755 index 000000000..3dc89a48e Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_dq_convert_fp16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_bf16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_bf16.co new file mode 100755 index 000000000..d0735a5bd Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_bf16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_bf16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_bf16_group.co new file mode 100755 index 000000000..c5349a960 Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_bf16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_fp16.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_fp16.co new file mode 100755 index 000000000..5590aba8b Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_fp16.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_fp16_group.co b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_fp16_group.co new file mode 100755 index 000000000..43f85f8ab Binary files /dev/null and b/hsa/gfx942/fmha_v3_bwd/bwd_hd64_odo_fp16_group.co differ diff --git a/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dq_convert.csv b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dq_convert.csv new file mode 100644 index 000000000..ff437a30d --- /dev/null +++ b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dq_convert.csv @@ -0,0 +1,25 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +fp16,64,64,0,0,0,0,0,3,0,64,_ZN5aiter29fmha_bwd_hd64_dq_convert_fp16E,bwd_hd64_dq_convert_fp16.co +fp16,64,64,0,0,0,0,1,3,0,64,_ZN5aiter35fmha_bwd_hd64_dq_convert_fp16_groupE,bwd_hd64_dq_convert_fp16_group.co +bf16,64,64,0,0,0,0,0,0,0,64,_ZN5aiter34fmha_bwd_hd64_dq_convert_bf16_rtneE,bwd_hd64_dq_convert_bf16_rtne.co +bf16,64,64,0,0,0,0,0,1,0,64,_ZN5aiter34fmha_bwd_hd64_dq_convert_bf16_rtnaE,bwd_hd64_dq_convert_bf16_rtna.co +bf16,64,64,0,0,0,0,0,2,0,64,_ZN5aiter33fmha_bwd_hd64_dq_convert_bf16_rtzE,bwd_hd64_dq_convert_bf16_rtz.co +bf16,64,64,0,0,0,0,1,0,0,64,_ZN5aiter40fmha_bwd_hd64_dq_convert_bf16_rtne_groupE,bwd_hd64_dq_convert_bf16_rtne_group.co +bf16,64,64,0,0,0,0,1,1,0,64,_ZN5aiter40fmha_bwd_hd64_dq_convert_bf16_rtna_groupE,bwd_hd64_dq_convert_bf16_rtna_group.co +bf16,64,64,0,0,0,0,1,2,0,64,_ZN5aiter39fmha_bwd_hd64_dq_convert_bf16_rtz_groupE,bwd_hd64_dq_convert_bf16_rtz_group.co +fp16,128,128,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd128_dq_convert_fp16E,bwd_hd128_dq_convert_fp16.co +fp16,128,128,0,0,0,0,1,3,0,64,_ZN5aiter36fmha_bwd_hd128_dq_convert_fp16_groupE,bwd_hd128_dq_convert_fp16_group.co +bf16,128,128,0,0,0,0,0,0,0,64,_ZN5aiter35fmha_bwd_hd128_dq_convert_bf16_rtneE,bwd_hd128_dq_convert_bf16_rtne.co +bf16,128,128,0,0,0,0,0,1,0,64,_ZN5aiter35fmha_bwd_hd128_dq_convert_bf16_rtnaE,bwd_hd128_dq_convert_bf16_rtna.co +bf16,128,128,0,0,0,0,0,2,0,64,_ZN5aiter34fmha_bwd_hd128_dq_convert_bf16_rtzE,bwd_hd128_dq_convert_bf16_rtz.co +bf16,128,128,0,0,0,0,1,0,0,64,_ZN5aiter41fmha_bwd_hd128_dq_convert_bf16_rtne_groupE,bwd_hd128_dq_convert_bf16_rtne_group.co +bf16,128,128,0,0,0,0,1,1,0,64,_ZN5aiter41fmha_bwd_hd128_dq_convert_bf16_rtna_groupE,bwd_hd128_dq_convert_bf16_rtna_group.co +bf16,128,128,0,0,0,0,1,2,0,64,_ZN5aiter40fmha_bwd_hd128_dq_convert_bf16_rtz_groupE,bwd_hd128_dq_convert_bf16_rtz_group.co +fp16,192,192,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd192_dq_convert_fp16E,bwd_hd192_dq_convert_fp16.co +fp16,192,192,0,0,0,0,1,3,0,64,_ZN5aiter36fmha_bwd_hd192_dq_convert_fp16_groupE,bwd_hd192_dq_convert_fp16_group.co +bf16,192,192,0,0,0,0,0,0,0,64,_ZN5aiter35fmha_bwd_hd192_dq_convert_bf16_rtneE,bwd_hd192_dq_convert_bf16_rtne.co +bf16,192,192,0,0,0,0,0,1,0,64,_ZN5aiter35fmha_bwd_hd192_dq_convert_bf16_rtnaE,bwd_hd192_dq_convert_bf16_rtna.co +bf16,192,192,0,0,0,0,0,2,0,64,_ZN5aiter34fmha_bwd_hd192_dq_convert_bf16_rtzE,bwd_hd192_dq_convert_bf16_rtz.co +bf16,192,192,0,0,0,0,1,0,0,64,_ZN5aiter41fmha_bwd_hd192_dq_convert_bf16_rtne_groupE,bwd_hd192_dq_convert_bf16_rtne_group.co +bf16,192,192,0,0,0,0,1,1,0,64,_ZN5aiter41fmha_bwd_hd192_dq_convert_bf16_rtna_groupE,bwd_hd192_dq_convert_bf16_rtna_group.co +bf16,192,192,0,0,0,0,1,2,0,64,_ZN5aiter40fmha_bwd_hd192_dq_convert_bf16_rtz_groupE,bwd_hd192_dq_convert_bf16_rtz_group.co diff --git a/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dq_shuffle.csv b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dq_shuffle.csv new file mode 100644 index 000000000..ba38ee57c --- /dev/null +++ b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dq_shuffle.csv @@ -0,0 +1 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name diff --git a/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dqdkdv.csv b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dqdkdv.csv new file mode 100644 index 000000000..3c4a67ac1 --- /dev/null +++ b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_dqdkdv.csv @@ -0,0 +1,121 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +fp16,192,192,0,1,1,1,0,3,16,64,_ZN5aiter31fmha_bwd_hd192_fp16_a32_psskddvE,bwd_hd192_fp16_a32_psskddv.co +fp16,128,128,1,1,1,0,1,3,16,192,_ZN5aiter41fmha_bwd_hd128_fp16_causal_a32_pssk_groupE,bwd_hd128_fp16_causal_a32_pssk_group.co +fp16,128,128,1,1,1,1,1,3,16,192,_ZN5aiter44fmha_bwd_hd128_fp16_causal_a32_psskddv_groupE,bwd_hd128_fp16_causal_a32_psskddv_group.co +fp16,128,128,0,0,0,0,0,3,16,192,_ZN5aiter23fmha_bwd_hd128_fp16_a16E,bwd_hd128_fp16_a16.co +fp16,128,128,0,0,0,1,0,3,16,192,_ZN5aiter28fmha_bwd_hd128_fp16_a16_pddvE,bwd_hd128_fp16_a16_pddv.co +fp16,128,128,1,0,0,0,0,3,16,192,_ZN5aiter30fmha_bwd_hd128_fp16_causal_a16E,bwd_hd128_fp16_causal_a16.co +fp16,128,128,1,0,0,1,0,3,16,192,_ZN5aiter35fmha_bwd_hd128_fp16_causal_a16_pddvE,bwd_hd128_fp16_causal_a16_pddv.co +fp16,128,128,0,1,0,0,0,3,16,192,_ZN5aiter23fmha_bwd_hd128_fp16_a32E,bwd_hd128_fp16_a32.co +fp16,128,128,0,1,1,1,0,3,16,192,_ZN5aiter31fmha_bwd_hd128_fp16_a32_psskddvE,bwd_hd128_fp16_a32_psskddv.co +fp16,128,128,0,1,1,0,1,3,16,192,_ZN5aiter34fmha_bwd_hd128_fp16_a32_pssk_groupE,bwd_hd128_fp16_a32_pssk_group.co +fp16,128,128,0,1,1,1,1,3,16,192,_ZN5aiter37fmha_bwd_hd128_fp16_a32_psskddv_groupE,bwd_hd128_fp16_a32_psskddv_group.co +fp16,128,128,1,1,0,0,0,3,16,192,_ZN5aiter30fmha_bwd_hd128_fp16_causal_a32E,bwd_hd128_fp16_causal_a32.co +fp16,128,128,1,1,1,1,0,3,16,192,_ZN5aiter38fmha_bwd_hd128_fp16_causal_a32_psskddvE,bwd_hd128_fp16_causal_a32_psskddv.co +fp16,128,128,3,1,1,1,0,3,16,192,_ZN5aiter35fmha_bwd_hd128_fp16_swa_a32_psskddvE,bwd_hd128_fp16_swa_a32_psskddv.co +fp16,192,192,1,1,1,1,0,3,16,64,_ZN5aiter38fmha_bwd_hd192_fp16_causal_a32_psskddvE,bwd_hd192_fp16_causal_a32_psskddv.co +fp16,64,64,0,0,0,0,0,3,32,192,_ZN5aiter22fmha_bwd_hd64_fp16_a16E,bwd_hd64_fp16_a16.co +fp16,64,64,1,0,0,0,0,3,32,192,_ZN5aiter29fmha_bwd_hd64_fp16_causal_a16E,bwd_hd64_fp16_causal_a16.co +fp16,64,64,0,1,1,0,0,3,32,192,_ZN5aiter27fmha_bwd_hd64_fp16_a32_psskE,bwd_hd64_fp16_a32_pssk.co +fp16,64,64,1,1,1,0,0,3,32,192,_ZN5aiter34fmha_bwd_hd64_fp16_causal_a32_psskE,bwd_hd64_fp16_causal_a32_pssk.co +fp16,128,128,2,1,1,1,0,3,16,192,_ZN5aiter41fmha_bwd_hd128_fp16_causal_br_a32_psskddvE,bwd_hd128_fp16_causal_br_a32_psskddv.co +fp16,128,128,2,1,1,0,1,3,16,192,_ZN5aiter44fmha_bwd_hd128_fp16_causal_br_a32_pssk_groupE,bwd_hd128_fp16_causal_br_a32_pssk_group.co +fp16,128,128,2,1,1,1,1,3,16,192,_ZN5aiter47fmha_bwd_hd128_fp16_causal_br_a32_psskddv_groupE,bwd_hd128_fp16_causal_br_a32_psskddv_group.co +fp16,192,192,0,1,1,1,1,3,16,64,_ZN5aiter37fmha_bwd_hd192_fp16_a32_psskddv_groupE,bwd_hd192_fp16_a32_psskddv_group.co +fp16,192,192,2,1,1,1,0,3,16,64,_ZN5aiter41fmha_bwd_hd192_fp16_causal_br_a32_psskddvE,bwd_hd192_fp16_causal_br_a32_psskddv.co +fp16,192,192,2,1,1,1,1,3,16,64,_ZN5aiter47fmha_bwd_hd192_fp16_causal_br_a32_psskddv_groupE,bwd_hd192_fp16_causal_br_a32_psskddv_group.co +fp16,192,192,1,1,1,1,1,3,16,64,_ZN5aiter44fmha_bwd_hd192_fp16_causal_a32_psskddv_groupE,bwd_hd192_fp16_causal_a32_psskddv_group.co +fp16,64,64,0,1,1,0,1,3,32,192,_ZN5aiter33fmha_bwd_hd64_fp16_a32_pssk_groupE,bwd_hd64_fp16_a32_pssk_group.co +fp16,64,64,2,1,1,0,0,3,32,192,_ZN5aiter37fmha_bwd_hd64_fp16_causal_br_a32_psskE,bwd_hd64_fp16_causal_br_a32_pssk.co +fp16,64,64,2,1,1,0,1,3,32,192,_ZN5aiter43fmha_bwd_hd64_fp16_causal_br_a32_pssk_groupE,bwd_hd64_fp16_causal_br_a32_pssk_group.co +fp16,64,64,1,1,1,0,1,3,32,192,_ZN5aiter40fmha_bwd_hd64_fp16_causal_a32_pssk_groupE,bwd_hd64_fp16_causal_a32_pssk_group.co +bf16,64,64,2,1,1,0,1,0,32,192,_ZN5aiter48fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_groupE,bwd_hd64_bf16_causal_br_a32_rtne_pssk_group.co +bf16,64,64,2,1,1,0,1,1,32,192,_ZN5aiter48fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_groupE,bwd_hd64_bf16_causal_br_a32_rtna_pssk_group.co +bf16,64,64,2,1,1,0,1,2,32,192,_ZN5aiter47fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_groupE,bwd_hd64_bf16_causal_br_a32_rtz_pssk_group.co +bf16,64,64,1,1,1,0,0,0,32,192,_ZN5aiter39fmha_bwd_hd64_bf16_causal_a32_rtne_psskE,bwd_hd64_bf16_causal_a32_rtne_pssk.co +bf16,64,64,1,1,1,0,0,1,32,192,_ZN5aiter39fmha_bwd_hd64_bf16_causal_a32_rtna_psskE,bwd_hd64_bf16_causal_a32_rtna_pssk.co +bf16,64,64,1,1,1,0,0,2,32,192,_ZN5aiter38fmha_bwd_hd64_bf16_causal_a32_rtz_psskE,bwd_hd64_bf16_causal_a32_rtz_pssk.co +bf16,64,64,1,1,1,0,1,0,32,192,_ZN5aiter45fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_groupE,bwd_hd64_bf16_causal_a32_rtne_pssk_group.co +bf16,64,64,1,1,1,0,1,1,32,192,_ZN5aiter45fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_groupE,bwd_hd64_bf16_causal_a32_rtna_pssk_group.co +bf16,64,64,1,1,1,0,1,2,32,192,_ZN5aiter44fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_groupE,bwd_hd64_bf16_causal_a32_rtz_pssk_group.co +bf16,192,192,0,1,1,1,1,0,16,64,_ZN5aiter42fmha_bwd_hd192_bf16_a32_rtne_psskddv_groupE,bwd_hd192_bf16_a32_rtne_psskddv_group.co +bf16,192,192,0,1,1,1,1,1,16,64,_ZN5aiter42fmha_bwd_hd192_bf16_a32_rtna_psskddv_groupE,bwd_hd192_bf16_a32_rtna_psskddv_group.co +bf16,192,192,0,1,1,1,1,2,16,64,_ZN5aiter41fmha_bwd_hd192_bf16_a32_rtz_psskddv_groupE,bwd_hd192_bf16_a32_rtz_psskddv_group.co +bf16,192,192,0,1,1,1,0,0,16,64,_ZN5aiter36fmha_bwd_hd192_bf16_a32_rtne_psskddvE,bwd_hd192_bf16_a32_rtne_psskddv.co +bf16,192,192,0,1,1,1,0,1,16,64,_ZN5aiter36fmha_bwd_hd192_bf16_a32_rtna_psskddvE,bwd_hd192_bf16_a32_rtna_psskddv.co +bf16,192,192,0,1,1,1,0,2,16,64,_ZN5aiter35fmha_bwd_hd192_bf16_a32_rtz_psskddvE,bwd_hd192_bf16_a32_rtz_psskddv.co +bf16,192,192,2,1,1,1,0,0,16,64,_ZN5aiter46fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddvE,bwd_hd192_bf16_causal_br_a32_rtne_psskddv.co +bf16,192,192,2,1,1,1,0,1,16,64,_ZN5aiter46fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddvE,bwd_hd192_bf16_causal_br_a32_rtna_psskddv.co +bf16,192,192,2,1,1,1,0,2,16,64,_ZN5aiter45fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddvE,bwd_hd192_bf16_causal_br_a32_rtz_psskddv.co +bf16,192,192,2,1,1,1,1,0,16,64,_ZN5aiter52fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_groupE,bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group.co +bf16,192,192,2,1,1,1,1,1,16,64,_ZN5aiter52fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_groupE,bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group.co +bf16,192,192,2,1,1,1,1,2,16,64,_ZN5aiter51fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_groupE,bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group.co +bf16,192,192,1,1,1,1,0,0,16,64,_ZN5aiter43fmha_bwd_hd192_bf16_causal_a32_rtne_psskddvE,bwd_hd192_bf16_causal_a32_rtne_psskddv.co +bf16,192,192,1,1,1,1,0,1,16,64,_ZN5aiter43fmha_bwd_hd192_bf16_causal_a32_rtna_psskddvE,bwd_hd192_bf16_causal_a32_rtna_psskddv.co +bf16,192,192,1,1,1,1,0,2,16,64,_ZN5aiter42fmha_bwd_hd192_bf16_causal_a32_rtz_psskddvE,bwd_hd192_bf16_causal_a32_rtz_psskddv.co +bf16,192,192,1,1,1,1,1,0,16,64,_ZN5aiter49fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_groupE,bwd_hd192_bf16_causal_a32_rtne_psskddv_group.co +bf16,192,192,1,1,1,1,1,1,16,64,_ZN5aiter49fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_groupE,bwd_hd192_bf16_causal_a32_rtna_psskddv_group.co +bf16,192,192,1,1,1,1,1,2,16,64,_ZN5aiter48fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_groupE,bwd_hd192_bf16_causal_a32_rtz_psskddv_group.co +bf16,64,64,1,0,0,0,0,0,32,192,_ZN5aiter34fmha_bwd_hd64_bf16_causal_a16_rtneE,bwd_hd64_bf16_causal_a16_rtne.co +bf16,64,64,1,0,0,0,0,1,32,192,_ZN5aiter34fmha_bwd_hd64_bf16_causal_a16_rtnaE,bwd_hd64_bf16_causal_a16_rtna.co +bf16,64,64,1,0,0,0,0,2,32,192,_ZN5aiter33fmha_bwd_hd64_bf16_causal_a16_rtzE,bwd_hd64_bf16_causal_a16_rtz.co +bf16,64,64,0,1,1,0,0,0,32,192,_ZN5aiter32fmha_bwd_hd64_bf16_a32_rtne_psskE,bwd_hd64_bf16_a32_rtne_pssk.co +bf16,64,64,0,1,1,0,0,1,32,192,_ZN5aiter32fmha_bwd_hd64_bf16_a32_rtna_psskE,bwd_hd64_bf16_a32_rtna_pssk.co +bf16,64,64,0,1,1,0,0,2,32,192,_ZN5aiter31fmha_bwd_hd64_bf16_a32_rtz_psskE,bwd_hd64_bf16_a32_rtz_pssk.co +bf16,64,64,0,1,1,0,1,0,32,192,_ZN5aiter38fmha_bwd_hd64_bf16_a32_rtne_pssk_groupE,bwd_hd64_bf16_a32_rtne_pssk_group.co +bf16,64,64,0,1,1,0,1,1,32,192,_ZN5aiter38fmha_bwd_hd64_bf16_a32_rtna_pssk_groupE,bwd_hd64_bf16_a32_rtna_pssk_group.co +bf16,64,64,0,1,1,0,1,2,32,192,_ZN5aiter37fmha_bwd_hd64_bf16_a32_rtz_pssk_groupE,bwd_hd64_bf16_a32_rtz_pssk_group.co +bf16,64,64,2,1,1,0,0,0,32,192,_ZN5aiter42fmha_bwd_hd64_bf16_causal_br_a32_rtne_psskE,bwd_hd64_bf16_causal_br_a32_rtne_pssk.co +bf16,64,64,2,1,1,0,0,1,32,192,_ZN5aiter42fmha_bwd_hd64_bf16_causal_br_a32_rtna_psskE,bwd_hd64_bf16_causal_br_a32_rtna_pssk.co +bf16,64,64,2,1,1,0,0,2,32,192,_ZN5aiter41fmha_bwd_hd64_bf16_causal_br_a32_rtz_psskE,bwd_hd64_bf16_causal_br_a32_rtz_pssk.co +bf16,128,128,0,0,0,0,0,0,16,192,_ZN5aiter28fmha_bwd_hd128_bf16_a16_rtneE,bwd_hd128_bf16_a16_rtne.co +bf16,128,128,0,0,0,0,0,1,16,192,_ZN5aiter28fmha_bwd_hd128_bf16_a16_rtnaE,bwd_hd128_bf16_a16_rtna.co +bf16,128,128,0,0,0,0,0,2,16,192,_ZN5aiter27fmha_bwd_hd128_bf16_a16_rtzE,bwd_hd128_bf16_a16_rtz.co +bf16,128,128,0,0,0,1,0,0,16,192,_ZN5aiter33fmha_bwd_hd128_bf16_a16_rtne_pddvE,bwd_hd128_bf16_a16_rtne_pddv.co +bf16,128,128,0,0,0,1,0,1,16,192,_ZN5aiter33fmha_bwd_hd128_bf16_a16_rtna_pddvE,bwd_hd128_bf16_a16_rtna_pddv.co +bf16,128,128,0,0,0,1,0,2,16,192,_ZN5aiter32fmha_bwd_hd128_bf16_a16_rtz_pddvE,bwd_hd128_bf16_a16_rtz_pddv.co +bf16,128,128,1,0,0,0,0,0,16,192,_ZN5aiter35fmha_bwd_hd128_bf16_causal_a16_rtneE,bwd_hd128_bf16_causal_a16_rtne.co +bf16,128,128,1,0,0,0,0,1,16,192,_ZN5aiter35fmha_bwd_hd128_bf16_causal_a16_rtnaE,bwd_hd128_bf16_causal_a16_rtna.co +bf16,128,128,1,0,0,0,0,2,16,192,_ZN5aiter34fmha_bwd_hd128_bf16_causal_a16_rtzE,bwd_hd128_bf16_causal_a16_rtz.co +bf16,128,128,1,0,0,1,0,0,16,192,_ZN5aiter40fmha_bwd_hd128_bf16_causal_a16_rtne_pddvE,bwd_hd128_bf16_causal_a16_rtne_pddv.co +bf16,128,128,1,0,0,1,0,1,16,192,_ZN5aiter40fmha_bwd_hd128_bf16_causal_a16_rtna_pddvE,bwd_hd128_bf16_causal_a16_rtna_pddv.co +bf16,128,128,1,0,0,1,0,2,16,192,_ZN5aiter39fmha_bwd_hd128_bf16_causal_a16_rtz_pddvE,bwd_hd128_bf16_causal_a16_rtz_pddv.co +bf16,128,128,0,1,0,0,0,0,16,192,_ZN5aiter28fmha_bwd_hd128_bf16_a32_rtneE,bwd_hd128_bf16_a32_rtne.co +bf16,128,128,0,1,0,0,0,1,16,192,_ZN5aiter28fmha_bwd_hd128_bf16_a32_rtnaE,bwd_hd128_bf16_a32_rtna.co +bf16,128,128,0,1,0,0,0,2,16,192,_ZN5aiter27fmha_bwd_hd128_bf16_a32_rtzE,bwd_hd128_bf16_a32_rtz.co +bf16,128,128,0,1,1,1,0,0,16,192,_ZN5aiter36fmha_bwd_hd128_bf16_a32_rtne_psskddvE,bwd_hd128_bf16_a32_rtne_psskddv.co +bf16,128,128,0,1,1,1,0,1,16,192,_ZN5aiter36fmha_bwd_hd128_bf16_a32_rtna_psskddvE,bwd_hd128_bf16_a32_rtna_psskddv.co +bf16,128,128,0,1,1,1,0,2,16,192,_ZN5aiter35fmha_bwd_hd128_bf16_a32_rtz_psskddvE,bwd_hd128_bf16_a32_rtz_psskddv.co +bf16,128,128,0,1,1,0,1,0,16,192,_ZN5aiter39fmha_bwd_hd128_bf16_a32_rtne_pssk_groupE,bwd_hd128_bf16_a32_rtne_pssk_group.co +bf16,128,128,0,1,1,1,1,0,16,192,_ZN5aiter42fmha_bwd_hd128_bf16_a32_rtne_psskddv_groupE,bwd_hd128_bf16_a32_rtne_psskddv_group.co +bf16,128,128,0,1,1,0,1,1,16,192,_ZN5aiter39fmha_bwd_hd128_bf16_a32_rtna_pssk_groupE,bwd_hd128_bf16_a32_rtna_pssk_group.co +bf16,128,128,0,1,1,1,1,1,16,192,_ZN5aiter42fmha_bwd_hd128_bf16_a32_rtna_psskddv_groupE,bwd_hd128_bf16_a32_rtna_psskddv_group.co +bf16,128,128,0,1,1,0,1,2,16,192,_ZN5aiter38fmha_bwd_hd128_bf16_a32_rtz_pssk_groupE,bwd_hd128_bf16_a32_rtz_pssk_group.co +bf16,128,128,0,1,1,1,1,2,16,192,_ZN5aiter41fmha_bwd_hd128_bf16_a32_rtz_psskddv_groupE,bwd_hd128_bf16_a32_rtz_psskddv_group.co +bf16,128,128,2,1,1,1,0,0,16,192,_ZN5aiter46fmha_bwd_hd128_bf16_causal_br_a32_rtne_psskddvE,bwd_hd128_bf16_causal_br_a32_rtne_psskddv.co +bf16,128,128,2,1,1,1,0,1,16,192,_ZN5aiter46fmha_bwd_hd128_bf16_causal_br_a32_rtna_psskddvE,bwd_hd128_bf16_causal_br_a32_rtna_psskddv.co +bf16,128,128,2,1,1,1,0,2,16,192,_ZN5aiter45fmha_bwd_hd128_bf16_causal_br_a32_rtz_psskddvE,bwd_hd128_bf16_causal_br_a32_rtz_psskddv.co +bf16,128,128,2,1,1,0,1,0,16,192,_ZN5aiter49fmha_bwd_hd128_bf16_causal_br_a32_rtne_pssk_groupE,bwd_hd128_bf16_causal_br_a32_rtne_pssk_group.co +bf16,128,128,2,1,1,1,1,0,16,192,_ZN5aiter52fmha_bwd_hd128_bf16_causal_br_a32_rtne_psskddv_groupE,bwd_hd128_bf16_causal_br_a32_rtne_psskddv_group.co +bf16,128,128,2,1,1,0,1,1,16,192,_ZN5aiter49fmha_bwd_hd128_bf16_causal_br_a32_rtna_pssk_groupE,bwd_hd128_bf16_causal_br_a32_rtna_pssk_group.co +bf16,128,128,2,1,1,1,1,1,16,192,_ZN5aiter52fmha_bwd_hd128_bf16_causal_br_a32_rtna_psskddv_groupE,bwd_hd128_bf16_causal_br_a32_rtna_psskddv_group.co +bf16,128,128,2,1,1,0,1,2,16,192,_ZN5aiter48fmha_bwd_hd128_bf16_causal_br_a32_rtz_pssk_groupE,bwd_hd128_bf16_causal_br_a32_rtz_pssk_group.co +bf16,128,128,2,1,1,1,1,2,16,192,_ZN5aiter51fmha_bwd_hd128_bf16_causal_br_a32_rtz_psskddv_groupE,bwd_hd128_bf16_causal_br_a32_rtz_psskddv_group.co +bf16,128,128,1,1,0,0,0,0,16,192,_ZN5aiter35fmha_bwd_hd128_bf16_causal_a32_rtneE,bwd_hd128_bf16_causal_a32_rtne.co +bf16,128,128,1,1,0,0,0,1,16,192,_ZN5aiter35fmha_bwd_hd128_bf16_causal_a32_rtnaE,bwd_hd128_bf16_causal_a32_rtna.co +bf16,128,128,1,1,0,0,0,2,16,192,_ZN5aiter34fmha_bwd_hd128_bf16_causal_a32_rtzE,bwd_hd128_bf16_causal_a32_rtz.co +bf16,128,128,1,1,1,1,0,0,16,192,_ZN5aiter43fmha_bwd_hd128_bf16_causal_a32_rtne_psskddvE,bwd_hd128_bf16_causal_a32_rtne_psskddv.co +bf16,128,128,1,1,1,1,0,1,16,192,_ZN5aiter43fmha_bwd_hd128_bf16_causal_a32_rtna_psskddvE,bwd_hd128_bf16_causal_a32_rtna_psskddv.co +bf16,128,128,1,1,1,1,0,2,16,192,_ZN5aiter42fmha_bwd_hd128_bf16_causal_a32_rtz_psskddvE,bwd_hd128_bf16_causal_a32_rtz_psskddv.co +bf16,128,128,1,1,1,0,1,0,16,192,_ZN5aiter46fmha_bwd_hd128_bf16_causal_a32_rtne_pssk_groupE,bwd_hd128_bf16_causal_a32_rtne_pssk_group.co +bf16,128,128,1,1,1,1,1,0,16,192,_ZN5aiter49fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv_groupE,bwd_hd128_bf16_causal_a32_rtne_psskddv_group.co +bf16,128,128,1,1,1,0,1,1,16,192,_ZN5aiter46fmha_bwd_hd128_bf16_causal_a32_rtna_pssk_groupE,bwd_hd128_bf16_causal_a32_rtna_pssk_group.co +bf16,128,128,1,1,1,1,1,1,16,192,_ZN5aiter49fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv_groupE,bwd_hd128_bf16_causal_a32_rtna_psskddv_group.co +bf16,128,128,1,1,1,0,1,2,16,192,_ZN5aiter45fmha_bwd_hd128_bf16_causal_a32_rtz_pssk_groupE,bwd_hd128_bf16_causal_a32_rtz_pssk_group.co +bf16,128,128,1,1,1,1,1,2,16,192,_ZN5aiter48fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv_groupE,bwd_hd128_bf16_causal_a32_rtz_psskddv_group.co +bf16,128,128,3,1,1,1,0,0,16,192,_ZN5aiter40fmha_bwd_hd128_bf16_swa_a32_rtne_psskddvE,bwd_hd128_bf16_swa_a32_rtne_psskddv.co +bf16,128,128,3,1,1,1,0,1,16,192,_ZN5aiter40fmha_bwd_hd128_bf16_swa_a32_rtna_psskddvE,bwd_hd128_bf16_swa_a32_rtna_psskddv.co +bf16,128,128,3,1,1,1,0,2,16,192,_ZN5aiter39fmha_bwd_hd128_bf16_swa_a32_rtz_psskddvE,bwd_hd128_bf16_swa_a32_rtz_psskddv.co +bf16,64,64,0,0,0,0,0,0,32,192,_ZN5aiter27fmha_bwd_hd64_bf16_a16_rtneE,bwd_hd64_bf16_a16_rtne.co +bf16,64,64,0,0,0,0,0,1,32,192,_ZN5aiter27fmha_bwd_hd64_bf16_a16_rtnaE,bwd_hd64_bf16_a16_rtna.co +bf16,64,64,0,0,0,0,0,2,32,192,_ZN5aiter26fmha_bwd_hd64_bf16_a16_rtzE,bwd_hd64_bf16_a16_rtz.co diff --git a/hsa/gfx942/fmha_v3_bwd/fmha_bwd_odo.csv b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_odo.csv new file mode 100644 index 000000000..c21eeff46 --- /dev/null +++ b/hsa/gfx942/fmha_v3_bwd/fmha_bwd_odo.csv @@ -0,0 +1,13 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +fp16,64,64,0,0,0,0,0,3,0,128,_ZN5aiter22fmha_bwd_hd64_odo_fp16E,bwd_hd64_odo_fp16.co +fp16,64,64,0,0,0,0,1,3,0,128,_ZN5aiter28fmha_bwd_hd64_odo_fp16_groupE,bwd_hd64_odo_fp16_group.co +bf16,64,64,0,0,0,0,0,3,0,128,_ZN5aiter22fmha_bwd_hd64_odo_bf16E,bwd_hd64_odo_bf16.co +bf16,64,64,0,0,0,0,1,3,0,128,_ZN5aiter28fmha_bwd_hd64_odo_bf16_groupE,bwd_hd64_odo_bf16_group.co +fp16,128,128,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd128_odo_fp16E,bwd_hd128_odo_fp16.co +fp16,128,128,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd128_odo_fp16_groupE,bwd_hd128_odo_fp16_group.co +bf16,128,128,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd128_odo_bf16E,bwd_hd128_odo_bf16.co +bf16,128,128,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd128_odo_bf16_groupE,bwd_hd128_odo_bf16_group.co +fp16,192,192,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd192_odo_fp16E,bwd_hd192_odo_fp16.co +fp16,192,192,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd192_odo_fp16_groupE,bwd_hd192_odo_fp16_group.co +bf16,192,192,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd192_odo_bf16E,bwd_hd192_odo_bf16.co +bf16,192,192,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd192_odo_bf16_groupE,bwd_hd192_odo_bf16_group.co diff --git a/hsa/gfx942/fmha_v3_fwd/aiter_hip_common.h b/hsa/gfx942/fmha_v3_fwd/aiter_hip_common.h new file mode 100644 index 000000000..c17f2b352 --- /dev/null +++ b/hsa/gfx942/fmha_v3_fwd/aiter_hip_common.h @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck_tile/core.hpp" +#include +#include + +enum class GPUArch +{ + gfx942, + gfx950 +}; + +#define HIP_CALL(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + printf("\n[AITER] %s:%d fail to call %s ---> [HIP error](%s)\n", \ + __FILE__, \ + __LINE__, \ + #call, \ + hipGetErrorString(err)); \ + exit(0); \ + } \ + } while(0) + +struct p3 +{ + unsigned int _p0; + unsigned int _p1; + unsigned int _p2; +}; +struct p2 +{ + unsigned int _p0; + unsigned int _p1; +}; +struct p1 +{ + unsigned int _p0; +}; + +struct AiterAsmKernelArgs +{ + void* args_ptr; + void* arg_size_ptr; + int gdx; + int gdy; + int gdz; + int bdx; + int bdy; + int bdz; + const hipStream_t stream; +}; + +class AiterAsmKernel +{ + private: + hipModule_t module; + hipFunction_t kernel_func; + + public: + AiterAsmKernel(const char* name, const char* hsaco) + { + const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); + std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() + << " GetFunction: " << name; + HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); + HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); + std::cout << " Success" << std::endl; + }; + + ~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); } + + void launch_kernel(const AiterAsmKernelArgs& kargs) + { + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kargs.args_ptr, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + kargs.arg_size_ptr, + HIP_LAUNCH_PARAM_END}; + + HIP_CALL(hipModuleLaunchKernel(kernel_func, + kargs.gdx, + kargs.gdy, + kargs.gdz, + kargs.bdx, + kargs.bdy, + kargs.bdz, + 0, + kargs.stream, + nullptr, + (void**)&config)); + }; +}; + +class AiterAsmKernelFast +{ + private: + hipModule_t module; + hipFunction_t kernel_func; + + public: + AiterAsmKernelFast(const char* name, void* hsaco) + { + HIP_CALL(hipModuleLoadData(&module, hsaco)); + HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); + std::cout << " Success" << std::endl; + }; + + ~AiterAsmKernelFast() { HIP_CALL(hipModuleUnload(module)); } + + void launch_kernel(const AiterAsmKernelArgs& kargs) + { + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kargs.args_ptr, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + kargs.arg_size_ptr, + HIP_LAUNCH_PARAM_END}; + + HIP_CALL(hipModuleLaunchKernel(kernel_func, + kargs.gdx, + kargs.gdy, + kargs.gdz, + kargs.bdx, + kargs.bdy, + kargs.bdz, + 0, + kargs.stream, + nullptr, + (void**)&config)); + }; +}; + +static const std::string get_gpu_arch() +{ + int device_count; + hipError_t err = hipGetDeviceCount(&device_count); + if(err != hipSuccess || device_count == 0) + { + return "No GPU Found"; + } + + hipDeviceProp_t prop; + err = hipGetDeviceProperties(&prop, 0); + if (err != hipSuccess) { + std::cerr << "Failed to get device properties: " << hipGetErrorString(err) << std::endl; + return {}; + } + + std::string arch_full = prop.gcnArchName; + size_t colon_pos = arch_full.find(':'); + if(colon_pos != std::string::npos) + { + return arch_full.substr(0, colon_pos); + } + else + { + return arch_full; + } +} + +static const uint32_t get_num_cu_func() +{ + auto get_num_cu_local = []() { + hipDevice_t dev; + hipDeviceProp_t dev_prop; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }; + static const uint32_t num_cu = get_num_cu_local(); + return num_cu; +} diff --git a/hsa/gfx942/fmha_v3_fwd/fmha_v3_fwd.cpp b/hsa/gfx942/fmha_v3_fwd/fmha_v3_fwd.cpp new file mode 100644 index 000000000..84a73755a --- /dev/null +++ b/hsa/gfx942/fmha_v3_fwd/fmha_v3_fwd.cpp @@ -0,0 +1,133 @@ +#include +#include +#include "toy_format.hpp" +#include "aiter_hip_common.h" +// #include "asm_fmha_v3_fwd.hpp" + +using namespace std; + +namespace aiter { + +enum class DataType { + FP16, + BF16 +}; + +struct fmha_fwd_traits +{ + DataType dtype; + int head_dim; + int mask_type; // 0: no mask, 1: top_left, 2: bottom_right, 3: sliding window + int bf16_cvt; // 0: rtz, 1: rtna, 2: rtne + int mode; // 0: batch, 1: group +}; + +static unordered_map bf16_cvt_map = { + {0, "_rtz"}, + {1, "_rtna"}, + {2, "_rtne"}, + {3, ""} +}; + +template +class FmhaFwdKernelSelector { +public: + FmhaFwdKernelSelector(fmha_fwd_traits traits) { + file_name = GetFileName(traits); + kernel_name = GetKernelName(traits); + } + + ~FmhaFwdKernelSelector() {} + + string file_name; + string kernel_name; +private: + static string GetMetaNameFromArgs(fmha_fwd_traits traits) { + if constexpr (arch == GPUArch::gfx950) { + traits.bf16_cvt = 3; // no bf16 cvt for gfx950 + } + string meta_name = format("hd{}_{}{}{}{}", + traits.head_dim, + traits.dtype == DataType::FP16 ? "fp16" : "bf16", // FIXME: add more dtypes if needed + traits.mask_type == 2 ? "_causal" : "", + GetBf16Cvt(traits.bf16_cvt), + traits.mode == 0 ? "" : "_group"); + + return meta_name; + } + + static string GetBf16Cvt(int cvt_type) { + if constexpr (arch == GPUArch::gfx950) { + return ""; + } else { + return bf16_cvt_map[cvt_type]; + } + } + + string GetKernelName(fmha_fwd_traits traits) { + string meta_name = GetMetaNameFromArgs(traits); + string length = to_string(meta_name.length()); + return format("ZN5aiter{}fmha_fwd_{}E", length, meta_name); + } + + string GetFileName(fmha_fwd_traits traits) { + string meta_name = GetMetaNameFromArgs(traits); + if constexpr (arch == GPUArch::gfx950) { + return format("fmha_v3_fwd/fwd_{}.co", GetMetaNameFromArgs(traits)); + } else { + return format("fmha_v3_fwd/{}/fwd_{}.co", GetCUDir(), GetMetaNameFromArgs(traits)); + } + } + + string GetCUDir() { + uint32_t cu_num = get_num_cu_func(); + if (cu_num == 304) { + return "MI300"; + } else if (cu_num == 80 || cu_num == 64) { + return "MI308"; + } else { + std::cout << cu_num << std::endl; + return {}; + } + } +}; +} // namespace aiter + +template +class FmhaFwdKernelDispatcher { +public: + FmhaFwdKernelDispatcher(aiter::fmha_fwd_traits traits, KernelList list) { + // Dispatch logic to select the appropriate kernel from KernelList based on traits + // This is a placeholder for actual dispatch logic + FmhaFwdKernelSelector ks(traits); + selected_kernel = ks.file_name; + + + if (list.find(selected_kernel) != list.end()) { + std::cout << "Selected kernel: " << selected_kernel << std::endl; + AiterAsmKernel *k(ks.file_name, ks.kernel_name); + } else { + std::cout << "No suitable kernel found for the given traits." << std::endl; + } + } + + ~FmhaFwdKernelDispatcher() {} + + string selected_kernel; + AiterAsmKernel k; +}; + +int main () { + aiter::fmha_fwd_traits traits; + traits.head_dim = 128; + traits.dtype = aiter::DataType::BF16; + traits.mask_type = 2; + traits.bf16_cvt = 1; + traits.mode = 0; + + aiter::FmhaFwdKernelSelector ks(traits); + std::cout << "kernel_name: " << ks.kernel_name << std::endl; + std::cout << "file_name: " << ks.file_name << std::endl; + + return 0; +} diff --git a/hsa/gfx942/fmha_v3_fwd/toy_format.hpp b/hsa/gfx942/fmha_v3_fwd/toy_format.hpp new file mode 100644 index 000000000..025706720 --- /dev/null +++ b/hsa/gfx942/fmha_v3_fwd/toy_format.hpp @@ -0,0 +1,48 @@ +#ifndef TOY_FORMAT_HPP +#define TOY_FORMAT_HPP + +#include +#include +#include + +namespace std { + +namespace detail { + +template +std::string to_string_helper(const T& v) { + std::ostringstream os; + os << v; + return os.str(); +} + +inline std::string format_impl(std::string_view fmt) { + return std::string(fmt); +} + +template +std::string format_impl(std::string_view fmt, + const Arg& first, + const Args&... rest) { + std::size_t pos = fmt.find("{}"); + if (pos == std::string_view::npos) + throw std::runtime_error("extra argument provided to format"); + + return std::string(fmt.substr(0, pos)) + + to_string_helper(first) + + format_impl(fmt.substr(pos + 2), rest...); +} + +} // namespace detail + +template +std::string format(std::string_view fmt, const Args&... args) { + std::string result = detail::format_impl(fmt, args...); + if (result.find("{}") != std::string::npos) + throw std::runtime_error("too few arguments provided to format"); + return result; +} + +} // namespace std + +#endif // TOY_FORMAT_HPP diff --git a/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dq_convert.csv b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dq_convert.csv new file mode 100644 index 000000000..fbe3051e8 --- /dev/null +++ b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dq_convert.csv @@ -0,0 +1,13 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +fp16,64,64,0,0,0,0,0,3,0,64,_ZN5aiter29fmha_bwd_hd64_dq_convert_fp16E,bwd_hd64_dq_convert_fp16.co +fp16,64,64,0,0,0,0,1,3,0,64,_ZN5aiter35fmha_bwd_hd64_dq_convert_fp16_groupE,bwd_hd64_dq_convert_fp16_group.co +bf16,64,64,0,0,0,0,0,3,0,64,_ZN5aiter29fmha_bwd_hd64_dq_convert_bf16E,bwd_hd64_dq_convert_bf16.co +bf16,64,64,0,0,0,0,1,3,0,64,_ZN5aiter35fmha_bwd_hd64_dq_convert_bf16_groupE,bwd_hd64_dq_convert_bf16_group.co +fp16,128,128,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd128_dq_convert_fp16E,bwd_hd128_dq_convert_fp16.co +fp16,128,128,0,0,0,0,1,3,0,64,_ZN5aiter36fmha_bwd_hd128_dq_convert_fp16_groupE,bwd_hd128_dq_convert_fp16_group.co +bf16,128,128,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd128_dq_convert_bf16E,bwd_hd128_dq_convert_bf16.co +bf16,128,128,0,0,0,0,1,3,0,64,_ZN5aiter36fmha_bwd_hd128_dq_convert_bf16_groupE,bwd_hd128_dq_convert_bf16_group.co +fp16,192,192,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd192_dq_convert_fp16E,bwd_hd192_dq_convert_fp16.co +fp16,192,192,0,0,0,0,1,3,0,64,_ZN5aiter36fmha_bwd_hd192_dq_convert_fp16_groupE,bwd_hd192_dq_convert_fp16_group.co +bf16,192,192,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd192_dq_convert_bf16E,bwd_hd192_dq_convert_bf16.co +bf16,192,192,0,0,0,0,1,3,0,64,_ZN5aiter36fmha_bwd_hd192_dq_convert_bf16_groupE,bwd_hd192_dq_convert_bf16_group.co diff --git a/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dq_shuffle.csv b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dq_shuffle.csv new file mode 100644 index 000000000..c8b613b8b --- /dev/null +++ b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dq_shuffle.csv @@ -0,0 +1,5 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +bf16,192,192,0,0,0,0,0,3,0,64,_ZN5aiter25fmha_bwd_hd192_dq_shuffleE,bwd_hd192_dq_shuffle.co +bf16,192,192,0,0,0,0,1,3,0,64,_ZN5aiter31fmha_bwd_hd192_dq_shuffle_groupE,bwd_hd192_dq_shuffle_group.co +bf16,128,128,0,0,0,0,0,3,0,64,_ZN5aiter25fmha_bwd_hd128_dq_shuffleE,bwd_hd128_dq_shuffle.co +bf16,128,128,0,0,0,0,1,3,0,64,_ZN5aiter31fmha_bwd_hd128_dq_shuffle_groupE,bwd_hd128_dq_shuffle_group.co diff --git a/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dqdkdv.csv b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dqdkdv.csv new file mode 100644 index 000000000..612ef5657 --- /dev/null +++ b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_dqdkdv.csv @@ -0,0 +1,33 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +fp16,128,128,0,0,1,1,0,3,16,256,_ZN5aiter31fmha_bwd_hd128_fp16_a16_psskddvE,bwd_hd128_fp16_a16_psskddv.co +fp16,128,128,0,0,1,1,1,3,16,256,_ZN5aiter37fmha_bwd_hd128_fp16_a16_psskddv_groupE,bwd_hd128_fp16_a16_psskddv_group.co +bf16,128,128,0,0,1,1,0,3,16,256,_ZN5aiter31fmha_bwd_hd128_bf16_a16_psskddvE,bwd_hd128_bf16_a16_psskddv.co +bf16,128,128,0,0,1,1,1,3,16,256,_ZN5aiter37fmha_bwd_hd128_bf16_a16_psskddv_groupE,bwd_hd128_bf16_a16_psskddv_group.co +fp16,192,128,0,0,1,1,0,3,16,192,_ZN5aiter32fmha_bwd_hd192_128_fp16_a16_psskE,bwd_hd192_128_fp16_a16_pssk.co +bf16,192,128,0,0,1,1,0,3,16,192,_ZN5aiter32fmha_bwd_hd192_128_bf16_a16_psskE,bwd_hd192_128_bf16_a16_pssk.co +fp16,192,128,1,1,1,1,0,3,16,192,_ZN5aiter39fmha_bwd_hd192_128_fp16_causal_a32_psskE,bwd_hd192_128_fp16_causal_a32_pssk.co +bf16,192,128,1,1,1,1,0,3,16,192,_ZN5aiter39fmha_bwd_hd192_128_bf16_causal_a32_psskE,bwd_hd192_128_bf16_causal_a32_pssk.co +fp16,128,128,1,1,1,1,0,3,16,256,_ZN5aiter38fmha_bwd_hd128_fp16_causal_a32_psskddvE,bwd_hd128_fp16_causal_a32_psskddv.co +fp16,128,128,1,1,1,1,1,3,16,256,_ZN5aiter44fmha_bwd_hd128_fp16_causal_a32_psskddv_groupE,bwd_hd128_fp16_causal_a32_psskddv_group.co +bf16,128,128,1,1,1,1,0,3,16,256,_ZN5aiter38fmha_bwd_hd128_bf16_causal_a32_psskddvE,bwd_hd128_bf16_causal_a32_psskddv.co +bf16,128,128,1,1,1,1,1,3,16,256,_ZN5aiter44fmha_bwd_hd128_bf16_causal_a32_psskddv_groupE,bwd_hd128_bf16_causal_a32_psskddv_group.co +fp16,128,128,0,1,1,1,0,3,16,256,_ZN5aiter31fmha_bwd_hd128_fp16_a32_psskddvE,bwd_hd128_fp16_a32_psskddv.co +fp16,128,128,0,1,1,1,1,3,16,256,_ZN5aiter37fmha_bwd_hd128_fp16_a32_psskddv_groupE,bwd_hd128_fp16_a32_psskddv_group.co +bf16,128,128,0,1,1,1,0,3,16,256,_ZN5aiter31fmha_bwd_hd128_bf16_a32_psskddvE,bwd_hd128_bf16_a32_psskddv.co +bf16,128,128,0,1,1,1,1,3,16,256,_ZN5aiter37fmha_bwd_hd128_bf16_a32_psskddv_groupE,bwd_hd128_bf16_a32_psskddv_group.co +fp16,192,128,0,1,1,1,0,3,16,192,_ZN5aiter32fmha_bwd_hd192_128_fp16_a32_psskE,bwd_hd192_128_fp16_a32_pssk.co +bf16,192,128,0,1,1,1,0,3,16,192,_ZN5aiter32fmha_bwd_hd192_128_bf16_a32_psskE,bwd_hd192_128_bf16_a32_pssk.co +fp16,128,128,2,0,1,1,0,3,16,256,_ZN5aiter41fmha_bwd_hd128_fp16_causal_br_a16_psskddvE,bwd_hd128_fp16_causal_br_a16_psskddv.co +fp16,128,128,2,0,1,1,1,3,16,256,_ZN5aiter47fmha_bwd_hd128_fp16_causal_br_a16_psskddv_groupE,bwd_hd128_fp16_causal_br_a16_psskddv_group.co +bf16,128,128,2,0,1,1,0,3,16,256,_ZN5aiter41fmha_bwd_hd128_bf16_causal_br_a16_psskddvE,bwd_hd128_bf16_causal_br_a16_psskddv.co +bf16,128,128,2,0,1,1,1,3,16,256,_ZN5aiter47fmha_bwd_hd128_bf16_causal_br_a16_psskddv_groupE,bwd_hd128_bf16_causal_br_a16_psskddv_group.co +fp16,128,128,2,1,1,1,0,3,16,256,_ZN5aiter41fmha_bwd_hd128_fp16_causal_br_a32_psskddvE,bwd_hd128_fp16_causal_br_a32_psskddv.co +fp16,128,128,2,1,1,1,1,3,16,256,_ZN5aiter47fmha_bwd_hd128_fp16_causal_br_a32_psskddv_groupE,bwd_hd128_fp16_causal_br_a32_psskddv_group.co +bf16,128,128,2,1,1,1,0,3,16,256,_ZN5aiter41fmha_bwd_hd128_bf16_causal_br_a32_psskddvE,bwd_hd128_bf16_causal_br_a32_psskddv.co +bf16,128,128,2,1,1,1,1,3,16,256,_ZN5aiter47fmha_bwd_hd128_bf16_causal_br_a32_psskddv_groupE,bwd_hd128_bf16_causal_br_a32_psskddv_group.co +fp16,192,128,1,0,1,1,0,3,16,192,_ZN5aiter39fmha_bwd_hd192_128_fp16_causal_a16_psskE,bwd_hd192_128_fp16_causal_a16_pssk.co +bf16,192,128,1,0,1,1,0,3,16,192,_ZN5aiter39fmha_bwd_hd192_128_bf16_causal_a16_psskE,bwd_hd192_128_bf16_causal_a16_pssk.co +fp16,128,128,1,0,1,1,0,3,16,256,_ZN5aiter38fmha_bwd_hd128_fp16_causal_a16_psskddvE,bwd_hd128_fp16_causal_a16_psskddv.co +fp16,128,128,1,0,1,1,1,3,16,256,_ZN5aiter44fmha_bwd_hd128_fp16_causal_a16_psskddv_groupE,bwd_hd128_fp16_causal_a16_psskddv_group.co +bf16,128,128,1,0,1,1,0,3,16,256,_ZN5aiter38fmha_bwd_hd128_bf16_causal_a16_psskddvE,bwd_hd128_bf16_causal_a16_psskddv.co +bf16,128,128,1,0,1,1,1,3,16,256,_ZN5aiter44fmha_bwd_hd128_bf16_causal_a16_psskddv_groupE,bwd_hd128_bf16_causal_a16_psskddv_group.co diff --git a/hsa/gfx950/fmha_v3_bwd/fmha_bwd_odo.csv b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_odo.csv new file mode 100644 index 000000000..c21eeff46 --- /dev/null +++ b/hsa/gfx950/fmha_v3_bwd/fmha_bwd_odo.csv @@ -0,0 +1,13 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +fp16,64,64,0,0,0,0,0,3,0,128,_ZN5aiter22fmha_bwd_hd64_odo_fp16E,bwd_hd64_odo_fp16.co +fp16,64,64,0,0,0,0,1,3,0,128,_ZN5aiter28fmha_bwd_hd64_odo_fp16_groupE,bwd_hd64_odo_fp16_group.co +bf16,64,64,0,0,0,0,0,3,0,128,_ZN5aiter22fmha_bwd_hd64_odo_bf16E,bwd_hd64_odo_bf16.co +bf16,64,64,0,0,0,0,1,3,0,128,_ZN5aiter28fmha_bwd_hd64_odo_bf16_groupE,bwd_hd64_odo_bf16_group.co +fp16,128,128,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd128_odo_fp16E,bwd_hd128_odo_fp16.co +fp16,128,128,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd128_odo_fp16_groupE,bwd_hd128_odo_fp16_group.co +bf16,128,128,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd128_odo_bf16E,bwd_hd128_odo_bf16.co +bf16,128,128,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd128_odo_bf16_groupE,bwd_hd128_odo_bf16_group.co +fp16,192,192,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd192_odo_fp16E,bwd_hd192_odo_fp16.co +fp16,192,192,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd192_odo_fp16_groupE,bwd_hd192_odo_fp16_group.co +bf16,192,192,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd192_odo_bf16E,bwd_hd192_odo_bf16.co +bf16,192,192,0,0,0,0,1,3,0,128,_ZN5aiter29fmha_bwd_hd192_odo_bf16_groupE,bwd_hd192_odo_bf16_group.co diff --git a/op_tests/cpp/mha/benchmark_mha_bwd.cpp b/op_tests/cpp/mha/benchmark_mha_bwd.cpp index aaee36f7e..e58fd7eff 100644 --- a/op_tests/cpp/mha/benchmark_mha_bwd.cpp +++ b/op_tests/cpp/mha/benchmark_mha_bwd.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" +#include "fmha_bwd.hpp" #include "mha_bwd.h" #include "utils.hpp" @@ -148,7 +149,7 @@ auto create_args(int argc, char* argv[]) "0", "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion " "will not be used") - .insert("bwd_v3", "0", "if set to 1, some cases will call the bwd v3 dqdkdv kernel") + .insert("bwd_v3", "0", "0: default(asm first) 1: force asm 2: force ck") .insert( "v3_atomic_fp32", "1", @@ -156,7 +157,7 @@ auto create_args(int argc, char* argv[]) .insert("v3_bf16_cvt", "1", "float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ") - .insert("is_v3_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel."); + .insert("v3_api_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -263,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bool bwd_v3 = arg_parser.get_bool("bwd_v3"); bool v3_atomic_fp32 = arg_parser.get_bool("v3_atomic_fp32"); int v3_bf16_cvt = arg_parser.get_int("v3_bf16_cvt"); - bool is_v3_check = arg_parser.get_bool("is_v3_check"); + bool v3_api_check = arg_parser.get_bool("v3_api_check"); ck_tile::stream_config stream_config{nullptr, true, @@ -413,9 +414,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-1.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution{-1.f, 1.f, seed}(bias_host); ck_tile::FillUniformDistribution{-1.f, 1.f, seed}(do_host); - // ck_tile::FillConstant{1}(q_host); - // ck_tile::FillConstant{1}(k_host); - // ck_tile::FillConstant{2}(do_host); + ck_tile::FillConstant{1}(q_host); + ck_tile::FillConstant{1}(k_host); + ck_tile::FillConstant{1}(v_host); + ck_tile::FillConstant{1}(do_host); } else if(init_method == 2) { @@ -493,7 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval - << ", deterministic:" << deterministic << ", mask:" << mask << std::flush; + << ", deterministic:" << deterministic << ", mask:" << mask << std::flush << std::endl; std::size_t workspace_size = dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024); @@ -504,7 +506,24 @@ bool run(const ck_tile::ArgParser& arg_parser) << " MByte memory workspace allocated" << std::endl; } - auto fmha_args = [&]() { + auto get_mask_type = [&]() { + if (mask.type == mask_enum::no_mask) { + return 0; + } else { + if (mask.type == mask_enum::window_generic) { + assert(false); + return 0; + } else { + if ((mask.left == -1) && (mask.right == 0)) { + return (mask.type == mask_enum::mask_top_left) ? 1 : 2; + } else { + return 3; + } + } + } + }; + + auto mha_args = [&]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -561,7 +580,33 @@ bool run(const ck_tile::ArgParser& arg_parser) } }(); - return fmha_bwd_args{q_buf.GetDeviceBuffer(), + return aiter::mha_bwd_args{nullptr, // stream_id_ + true, // time_kernel_ + /* log_level = */ (kname ? 1 : 0), // log_level_ + stream_warmup, // cold_niters_ + stream_repeat, // nrepeat_ + arg_parser.get_str("timer") == std::string("gpu"), // is_gpu_timer_ + false, // flush_cache_ + 1, // rotating_count_ + + get_mask_type(), + bwd_v3, + v3_atomic_fp32, + v3_bf16_cvt, + v3_api_check, + + hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + static_cast(mask.type), + static_cast(bias.type), + use_dbias, + p_drop > 0, + s_randval, + deterministic, + + q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() @@ -587,8 +632,6 @@ bool run(const ck_tile::ArgParser& arg_parser) batch, max_seqlen_q, max_seqlen_k, - hdim_q, - hdim_v, nhead, nhead_k, scale, @@ -634,30 +677,15 @@ bool run(const ck_tile::ArgParser& arg_parser) split_stride_dq_acc, mask.left, mask.right, - static_cast(mask.type), p_drop, p_undrop, drop_seed_offset}; }(); - float ave_time = aiter::mha_bwd(fmha_args, - stream_config, - data_type, - mode == mode_enum::group, - mask.type, - bias.type, - use_dbias, - s_randval, - deterministic, - bwd_v3, - v3_atomic_fp32, - v3_bf16_cvt, - nullptr, - nullptr, - is_v3_check); + float ave_time = aiter::mha_bwd(mha_args); if(ave_time < 0) { - std::cout << ", not supported yet" << std::flush << std::endl; + std::cout << "not supported yet" << std::flush << std::endl; return false; } @@ -665,7 +693,7 @@ bool run(const ck_tile::ArgParser& arg_parser) float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + std::cout << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush; @@ -891,18 +919,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; - aiter::mha_bwd(fmha_args, - stream_config_v, - data_type, - mode == mode_enum::group, - mask.type, - bias.type, - use_dbias, - s_randval, - deterministic, - bwd_v3, - v3_atomic_fp32, - v3_bf16_cvt); + aiter::mha_bwd(mha_args); dq_buf.FromDevice(dq_host.data()); dk_buf.FromDevice(dk_host.data()); diff --git a/op_tests/cpp/mha/compile.py b/op_tests/cpp/mha/compile.py index 6c84f8485..80213d095 100644 --- a/op_tests/cpp/mha/compile.py +++ b/op_tests/cpp/mha/compile.py @@ -13,17 +13,14 @@ FWD_CODEGEN_CMD = [] -BWD_CODEGEN_CMD = [] +BWD_CODEGEN_CMD = [f"{AITER_META_DIR}/hsa/codegen.py -m fmha_v3_bwd --output_dir {{}}"] def get_asm_dir(): for gfx in get_gfx_list(): FWD_ASM_DIR = f"{AITER_META_DIR}/hsa/{gfx}/fmha_v3_fwd" - BWD_ASM_DIR = f"{AITER_META_DIR}/hsa/{gfx}/fmha_v3_bwd" if os.path.exists(FWD_ASM_DIR): FWD_CODEGEN_CMD.append(f"{FWD_ASM_DIR}/codegen.py --output_dir {{}}") - if os.path.exists(BWD_ASM_DIR): - BWD_CODEGEN_CMD.append(f"{BWD_ASM_DIR}/codegen.py --output_dir {{}}") def cmdGenFunc_mha_fwd(ck_exclude: bool): @@ -55,10 +52,7 @@ def compile_mha_fwd(ck_exclude: bool): ... def cmdGenFunc_mha_bwd(ck_exclude: bool): if ck_exclude: - blob_gen_cmd = [ - f'{AITER_CSRC_DIR}/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py --filter "*@*_ndeterministic@*_nbias*_dropout*_ndeterministic*" --output_dir {{}}', - f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 2 --output_dir {{}}", - ] + blob_gen_cmd = [] else: blob_gen_cmd = [ f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd --receipt 600 --output_dir {{}}", @@ -66,6 +60,7 @@ def cmdGenFunc_mha_bwd(ck_exclude: bool): f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 3 --output_dir {{}}", ] blob_gen_cmd.extend(BWD_CODEGEN_CMD) + print(blob_gen_cmd) return { "md_name": "libmha_bwd", "blob_gen_cmd": blob_gen_cmd,