-
Notifications
You must be signed in to change notification settings - Fork 158
mha api refactor #1573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
mha api refactor #1573
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR refactors the Multi-Head Attention (MHA) backward pass API, consolidating arguments into a unified structure and simplifying the kernel selection and invocation process.
Key Changes
- Unified MHA backward arguments into a single
mha_bwd_argsstructure that encapsulates stream configuration, data types, and all tensor pointers - Introduced new CSV-based kernel configuration files for ASM kernel selection (odo, dqdkdv, dq_convert)
- Added new C++ implementation for v3 backward pass with heuristic kernel selection
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/cpp/mha/compile.py | Updated code generation commands for backward pass compilation |
| op_tests/cpp/mha/benchmark_mha_bwd.cpp | Simplified mha_bwd function call with new unified args structure |
| hsa/gfx942/fmha_v3_fwd/toy_format.hpp | Added custom string formatting utility for kernel name generation |
| hsa/gfx942/fmha_v3_fwd/fmha_v3_fwd.cpp | Added forward pass kernel selector with file/kernel name generation |
| hsa/gfx942/fmha_v3_fwd/aiter_hip_common.h | Added common HIP utilities and ASM kernel wrapper classes |
| hsa/gfx942/fmha_v3_bwd/*.csv | Added kernel configuration tables for backward pass kernels |
| csrc/include/mha_bwd.h | Refactored API with new args structure and detailed documentation |
| csrc/cpp_itfs/fmha_v3_bwd.cpp | Implemented new backward pass with heuristic kernel selection logic |
| aiter/jit/optCompilerConfig.json | Updated libmha_bwd configuration to include new source file |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| "1", | ||
| "float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ") | ||
| .insert("is_v3_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel."); | ||
| .insert("v3_api_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel."); |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Corrected spelling of 'scenarios' to 'scenario' for grammatical consistency.
| .insert("v3_api_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel."); | |
| .insert("v3_api_check", "0", "if set to 1, check whether the input scenario is supported by the asm kernel."); |
|
|
||
| if (list.find(selected_kernel) != list.end()) { | ||
| std::cout << "Selected kernel: " << selected_kernel << std::endl; | ||
| AiterAsmKernel *k(ks.file_name, ks.kernel_name); |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect pointer declaration syntax. Should use 'new' or remove the pointer declarator '*'.
| AiterAsmKernel *k(ks.file_name, ks.kernel_name); | |
| AiterAsmKernel k(ks.file_name, ks.kernel_name); |
| ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt)) && cfg.mode == mode) | ||
| { | ||
| if(ts_kv == 0) | ||
| { |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable 'ts_kv' is used before being initialized (set to 0 at line 64, but used in condition before being assigned from config).
| { | |
| { | |
| ts_kv = cfg.ts_kv; |
| odo_args.ptr_qseq = a.qseq_ptr; | ||
| odo_args.ptr_qseq_padded = a.qseq_padded_ptr; |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fields 'qseq_ptr' and 'qseq_padded_ptr' do not exist in the 'mha_bwd_args' struct definition. Should be 'seqstart_q_ptr' and equivalent for padded.
| odo_args.ptr_qseq = a.qseq_ptr; | |
| odo_args.ptr_qseq_padded = a.qseq_padded_ptr; | |
| odo_args.ptr_qseq = a.seqstart_q_ptr; | |
| odo_args.ptr_qseq_padded = a.seqstart_q_padded_ptr; |
|
|
||
| int bdx = 256; | ||
| int gdx = (a.seqlen_q + ts_odo - 1) / ts_odo; | ||
| int gdy = a.nhead; |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Field 'nhead' does not exist in 'mha_bwd_args' struct. Should be 'nhead_q'.
| int gdy = a.nhead; | |
| int gdy = a.nhead_q; |
| impl_ptr_pre->launch_kernel({&pre_args, | ||
| &arg_size, | ||
| gdx, | ||
| gdy, | ||
| gdz, | ||
| bdx, | ||
| 1, | ||
| 1, | ||
| a.stream.stream_id_, | ||
| NULL, | ||
| reinterpret_cast<void**>(&config)}); |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using 'impl_ptr_pre' and 'pre_args' in dqdkdv_kernel_launch function. Should use 'impl_ptr_dqdkdv' and 'dqdkdv_args'.
| impl_ptr_pre->launch_kernel({&pre_args, | |
| &arg_size, | |
| gdx, | |
| gdy, | |
| gdz, | |
| bdx, | |
| 1, | |
| 1, | |
| a.stream.stream_id_, | |
| NULL, | |
| reinterpret_cast<void**>(&config)}); | |
| impl_ptr_dqdkdv->launch_kernel({&dqdkdv_args, | |
| &arg_size, | |
| gdx, | |
| gdy, | |
| gdz, | |
| bdx, | |
| 1, | |
| 1, | |
| a.stream.stream_id_, | |
| NULL, | |
| reinterpret_cast<void**>(&config)}); |
| int gdy = a.nhead; | ||
| int gdz = a.batch; | ||
|
|
||
| impl_ptr_pre->launch_kernel({&pre_args, |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using 'impl_ptr_pre' and 'pre_args' in post_kernel_launch function. Should use 'impl_ptr_post' and 'post_args'.
| impl_ptr_pre->launch_kernel({&pre_args, | |
| impl_ptr_post->launch_kernel({&post_args, |
| post_args.ptr_qseq = a.qseq_ptr; | ||
| post_args.ptr_qseq_padded = a.qseq_padded_ptr; |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fields 'qseq_ptr' and 'qseq_padded_ptr' do not exist in 'mha_bwd_args' struct. Should be 'seqstart_q_ptr' and equivalent for padded.
| post_args.ptr_qseq = a.qseq_ptr; | |
| post_args.ptr_qseq_padded = a.qseq_padded_ptr; | |
| post_args.ptr_qseq = a.seqstart_q_ptr; | |
| post_args.ptr_qseq_padded = a.seqstart_q_padded_ptr; |
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist