Skip to content

Conversation

@slippedJim
Copy link
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings December 5, 2025 10:13
@slippedJim slippedJim marked this pull request as draft December 5, 2025 10:14
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the Multi-Head Attention (MHA) backward pass API, consolidating arguments into a unified structure and simplifying the kernel selection and invocation process.

Key Changes

  • Unified MHA backward arguments into a single mha_bwd_args structure that encapsulates stream configuration, data types, and all tensor pointers
  • Introduced new CSV-based kernel configuration files for ASM kernel selection (odo, dqdkdv, dq_convert)
  • Added new C++ implementation for v3 backward pass with heuristic kernel selection

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
op_tests/cpp/mha/compile.py Updated code generation commands for backward pass compilation
op_tests/cpp/mha/benchmark_mha_bwd.cpp Simplified mha_bwd function call with new unified args structure
hsa/gfx942/fmha_v3_fwd/toy_format.hpp Added custom string formatting utility for kernel name generation
hsa/gfx942/fmha_v3_fwd/fmha_v3_fwd.cpp Added forward pass kernel selector with file/kernel name generation
hsa/gfx942/fmha_v3_fwd/aiter_hip_common.h Added common HIP utilities and ASM kernel wrapper classes
hsa/gfx942/fmha_v3_bwd/*.csv Added kernel configuration tables for backward pass kernels
csrc/include/mha_bwd.h Refactored API with new args structure and detailed documentation
csrc/cpp_itfs/fmha_v3_bwd.cpp Implemented new backward pass with heuristic kernel selection logic
aiter/jit/optCompilerConfig.json Updated libmha_bwd configuration to include new source file

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

"1",
"float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ")
.insert("is_v3_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel.");
.insert("v3_api_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel.");
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'scenarios' to 'scenario' for grammatical consistency.

Suggested change
.insert("v3_api_check", "0", "if set to 1, check whether the input scenarios is supported by the asm kernel.");
.insert("v3_api_check", "0", "if set to 1, check whether the input scenario is supported by the asm kernel.");

Copilot uses AI. Check for mistakes.

if (list.find(selected_kernel) != list.end()) {
std::cout << "Selected kernel: " << selected_kernel << std::endl;
AiterAsmKernel *k(ks.file_name, ks.kernel_name);
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect pointer declaration syntax. Should use 'new' or remove the pointer declarator '*'.

Suggested change
AiterAsmKernel *k(ks.file_name, ks.kernel_name);
AiterAsmKernel k(ks.file_name, ks.kernel_name);

Copilot uses AI. Check for mistakes.
((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt)) && cfg.mode == mode)
{
if(ts_kv == 0)
{
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable 'ts_kv' is used before being initialized (set to 0 at line 64, but used in condition before being assigned from config).

Suggested change
{
{
ts_kv = cfg.ts_kv;

Copilot uses AI. Check for mistakes.
Comment on lines +317 to +318
odo_args.ptr_qseq = a.qseq_ptr;
odo_args.ptr_qseq_padded = a.qseq_padded_ptr;
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fields 'qseq_ptr' and 'qseq_padded_ptr' do not exist in the 'mha_bwd_args' struct definition. Should be 'seqstart_q_ptr' and equivalent for padded.

Suggested change
odo_args.ptr_qseq = a.qseq_ptr;
odo_args.ptr_qseq_padded = a.qseq_padded_ptr;
odo_args.ptr_qseq = a.seqstart_q_ptr;
odo_args.ptr_qseq_padded = a.seqstart_q_padded_ptr;

Copilot uses AI. Check for mistakes.

int bdx = 256;
int gdx = (a.seqlen_q + ts_odo - 1) / ts_odo;
int gdy = a.nhead;
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Field 'nhead' does not exist in 'mha_bwd_args' struct. Should be 'nhead_q'.

Suggested change
int gdy = a.nhead;
int gdy = a.nhead_q;

Copilot uses AI. Check for mistakes.
Comment on lines +367 to +377
impl_ptr_pre->launch_kernel({&pre_args,
&arg_size,
gdx,
gdy,
gdz,
bdx,
1,
1,
a.stream.stream_id_,
NULL,
reinterpret_cast<void**>(&config)});
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using 'impl_ptr_pre' and 'pre_args' in dqdkdv_kernel_launch function. Should use 'impl_ptr_dqdkdv' and 'dqdkdv_args'.

Suggested change
impl_ptr_pre->launch_kernel({&pre_args,
&arg_size,
gdx,
gdy,
gdz,
bdx,
1,
1,
a.stream.stream_id_,
NULL,
reinterpret_cast<void**>(&config)});
impl_ptr_dqdkdv->launch_kernel({&dqdkdv_args,
&arg_size,
gdx,
gdy,
gdz,
bdx,
1,
1,
a.stream.stream_id_,
NULL,
reinterpret_cast<void**>(&config)});

Copilot uses AI. Check for mistakes.
int gdy = a.nhead;
int gdz = a.batch;

impl_ptr_pre->launch_kernel({&pre_args,
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using 'impl_ptr_pre' and 'pre_args' in post_kernel_launch function. Should use 'impl_ptr_post' and 'post_args'.

Suggested change
impl_ptr_pre->launch_kernel({&pre_args,
impl_ptr_post->launch_kernel({&post_args,

Copilot uses AI. Check for mistakes.
Comment on lines +392 to +393
post_args.ptr_qseq = a.qseq_ptr;
post_args.ptr_qseq_padded = a.qseq_padded_ptr;
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fields 'qseq_ptr' and 'qseq_padded_ptr' do not exist in 'mha_bwd_args' struct. Should be 'seqstart_q_ptr' and equivalent for padded.

Suggested change
post_args.ptr_qseq = a.qseq_ptr;
post_args.ptr_qseq_padded = a.qseq_padded_ptr;
post_args.ptr_qseq = a.seqstart_q_ptr;
post_args.ptr_qseq_padded = a.seqstart_q_padded_ptr;

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants