mlapo add qdown output #4707
Conversation
Code Review
This pull request introduces a new 'qdown' output, controlled by the enable_inner_out flag. The changes correctly propagate this new parameter from the PyTorch operator definition down to the kernel implementation, adding new logic paths and updating function signatures accordingly. However, my review identified a critical bug related to duplicated arguments being passed to the kernel implementation, which could lead to incorrect computations. Additionally, there is a significant code duplication issue in the new kernel logic, which impacts maintainability. Addressing these points will improve the correctness and quality of the code.
    qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, inner_out_ptr, workspace_ptr,
    tiling_ptr, block_dim]() -> int {
    mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
        gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
The mla_preprocess_impl function is called with sin_ptr and cos_ptr passed for both (sin1, cos1) and (sin2, cos2) arguments. The kernel implementation seems to expect two distinct pairs of sin/cos tensors. Passing the same pointers for both pairs is likely a bug and could lead to incorrect calculations.
The mla_preprocess function signature should probably be updated to accept a second pair of sin/cos tensors, or if this is intentional, it should be clearly documented why the same tensors are passed twice.
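For illustration, a minimal sketch of the suggested fix, assuming the host-side operator is extended with a second rotary pair; the sin2_ptr/cos2_ptr names are hypothetical and not part of this PR:

// Hypothetical call with a distinct second sin/cos pair instead of reusing sin_ptr/cos_ptr:
mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
    gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr,
    sin_ptr, cos_ptr,    // (sin1, cos1)
    sin2_ptr, cos2_ptr,  // (sin2, cos2) -- would require new tensor arguments on mla_preprocess
    /* remaining arguments unchanged */ ...);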
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
    MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                   QuantMode::PER_TENSOR_ASYMM_QUANT>
        opBf16Cm0Qm0Inner(mlaTilingData, tiling);
    opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                           quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                           bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                           s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        opBf16Cm0Qm0Inner.ProcessCube();
    }
    if ASCEND_IS_AIV {
        opBf16Cm0Qm0Inner.ProcessVector();
    }
    break;
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
    MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                   QuantMode::PER_TENSOR_ASYMM_QUANT>
        opBf16Cm1Qm0Inner(mlaTilingData, tiling);
    opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                           quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                           bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                           s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        opBf16Cm1Qm0Inner.ProcessCube();
    }
    if ASCEND_IS_AIV {
        opBf16Cm1Qm0Inner.ProcessVector();
    }
    break;
}
case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
    MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                   QuantMode::PER_TENSOR_ASYMM_QUANT>
        opBf16Cm3Qm0Inner(mlaTilingData, tiling);
    opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
                           quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
                           bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
                           s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        opBf16Cm3Qm0Inner.ProcessCube();
    }
    if ASCEND_IS_AIV {
        opBf16Cm3Qm0Inner.ProcessVector();
    }
    break;
}
There is significant code duplication across the new case statements for _INNER keys. The logic inside case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER, case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER, and case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER is identical except for the CacheMode template parameter (0, 1, 3) for MLAPO_BF16_INNER::MLAOperation.
This duplication makes the code harder to read and maintain. Any future changes to this logic will need to be applied in three places, increasing the risk of errors.
To reduce duplication, you could use a helper function templated on the CacheMode to encapsulate the common logic. For example:
template <int CacheMode>
__aicore__ void ProcessInnerOp(MlaTilingData& mlaTilingData, GM_ADDR tiling, /* other params */) {
    MLAPO_BF16_INNER::MLAOperation<__bf16, CacheMode, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                   QuantMode::PER_TENSOR_ASYMM_QUANT>
        op(mlaTilingData, tiling);
    op.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
            quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
            bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
            s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        op.ProcessCube();
    }
    if ASCEND_IS_AIV {
        op.ProcessVector();
    }
}

// Then in the switch statement:
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER:
    ProcessInnerOp<0>(mlaTilingData, tiling, ...);
    break;
case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER:
    ProcessInnerOp<1>(mlaTilingData, tiling, ...);
    break;
case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER:
    ProcessInnerOp<3>(mlaTilingData, tiling, ...);
    break;

Since adding a new function might be a larger change, you could also use a macro to achieve a similar result within the current function body.
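As a rough sketch of that macro alternative (the macro name and the way the Init arguments are picked up from the enclosing scope are illustrative, not part of this PR):

// Illustrative macro: expands the shared Init/ProcessCube/ProcessVector body inline,
// so each *_INNER case only supplies the CacheMode and a distinct operation name.
#define MLAPO_RUN_INNER_OP(CACHE_MODE, OP_NAME)                                                           \
    MLAPO_BF16_INNER::MLAOperation<__bf16, CACHE_MODE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,    \
                                   QuantMode::PER_TENSOR_ASYMM_QUANT>                                     \
        OP_NAME(mlaTilingData, tiling);                                                                   \
    OP_NAME.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,                     \
                 quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,   \
                 bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, \
                 s1, s2, s3, s4, s5, innerOut);                                                           \
    if ASCEND_IS_AIC {                                                                                    \
        OP_NAME.ProcessCube();                                                                            \
    }                                                                                                     \
    if ASCEND_IS_AIV {                                                                                    \
        OP_NAME.ProcessVector();                                                                          \
    }

// Usage inside the existing switch:
case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
    MLAPO_RUN_INNER_OP(1, opBf16Cm1Qm0Inner)
    break;
}

Either form keeps the Init argument list in a single place, so future signature changes only need to be made once.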
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
584bc9e to 03aed97
Signed-off-by: h1074112368 <[email protected]>
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Signed-off-by: h1074112368 <[email protected]>
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Signed-off-by: h1074112368 <[email protected]>
### What this PR does / why we need it?
This PR adds a qdown output to the mlapo operation.

### Does this PR introduce _any_ user-facing change?
The mlapo operation gains a new enable_inner_out input.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: h1074112368 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>