mlapo add qdown output #4701
Conversation
Code Review
This pull request adds a 'qdown' output to the mla_preprocess operation. The changes involve updating kernel definitions, host-side tiling logic, and PyTorch bindings. My review has identified a few issues: a critical bug in torch_binding.cpp where the same sin/cos tensors are used for two different purposes, leading to incorrect calculations. Additionally, there's a typo in a new struct member and significant code duplication in the kernel implementation which should be refactored for better maintainability. The cleanup of unused code in torch_binding.cpp is a good improvement.
csrc/torch_binding.cpp
Outdated
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, bool enable_inner_out, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1, at::Tensor &inner_out)
This function only accepts one pair of cos and sin tensors. However, the underlying mla_preprocess_impl kernel is called with these same tensors for two different sets of arguments (sin1/cos1 and sin2/cos2). The kernel implementation uses these for distinct purposes (one for RmsNormAndRopeConvergence1 and another for ropeFp16). This will lead to incorrect rotary embedding calculations. The function signature should be updated to accept two distinct pairs of sin/cos tensors, and the call to mla_preprocess_impl should be updated accordingly.
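As a minimal sketch of that suggestion (the wrapper name mla_preprocess, the parameter names cos1/sin1/cos2/sin2, and the abridged argument list are all assumed here for illustration, not the actual signature):

#include <ATen/ATen.h>

// Illustrative only: the real binding and impl take many more arguments; this shows
// just the shape of the two-pair interface and how the pairs would be forwarded.
void mla_preprocess_impl(const at::Tensor &sin1, const at::Tensor &cos1,
                         const at::Tensor &sin2, const at::Tensor &cos2 /*, ... */);

void mla_preprocess(const at::Tensor &cos1, const at::Tensor &sin1,
                    const at::Tensor &cos2, const at::Tensor &sin2 /*, ... */) {
    // sin1/cos1 feed RmsNormAndRopeConvergence1 and sin2/cos2 feed ropeFp16,
    // instead of passing the same single pair for both purposes.
    mla_preprocess_impl(sin1, cos1, sin2, cos2 /*, ... */);
}

The op registration and any Python-side wrapper would then need the second sin/cos pair threaded through as well.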
uint32_t rmsQuantMin2{0};

// Inner
uint32_t hiddtenState{0};
There is a typo in this new struct member: hiddtenState should be hiddenState.
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
    MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
        QuantMode::PER_TENSOR_ASYMM_QUANT>
        opBf16Cm0Qm0Inner(mlaTilingData, tiling);
    opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
        quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
        bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
        s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        opBf16Cm0Qm0Inner.ProcessCube();
    }
    if ASCEND_IS_AIV {
        opBf16Cm0Qm0Inner.ProcessVector();
    }
    break;
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
    MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
        QuantMode::PER_TENSOR_ASYMM_QUANT>
        opBf16Cm1Qm0Inner(mlaTilingData, tiling);
    opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
        quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
        bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
        s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        opBf16Cm1Qm0Inner.ProcessCube();
    }
    if ASCEND_IS_AIV {
        opBf16Cm1Qm0Inner.ProcessVector();
    }
    break;
}
case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
    MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
        QuantMode::PER_TENSOR_ASYMM_QUANT>
        opBf16Cm3Qm0Inner(mlaTilingData, tiling);
    opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
        quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
        bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
        s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        opBf16Cm3Qm0Inner.ProcessCube();
    }
    if ASCEND_IS_AIV {
        opBf16Cm3Qm0Inner.ProcessVector();
    }
    break;
}
The newly added case blocks for KEY_BF16_CACHEMODE_*_INNER contain a significant amount of duplicated code. The logic within each block is almost identical, differing only by the CacheMode template parameter (0, 1, or 3) passed to MLAPO_BF16_INNER::MLAOperation. This code should be refactored into a template helper function to improve maintainability and reduce redundancy. For example:
template <int CacheMode>
void process_inner_op(MlaTilingData& mlaTilingData, GM_ADDR tiling, /* other params */) {
    MLAPO_BF16_INNER::MLAOperation<__bf16, CacheMode, ...> op(mlaTilingData, tiling);
    op.Init(...);
    if ASCEND_IS_AIC {
        op.ProcessCube();
    }
    if ASCEND_IS_AIV {
        op.ProcessVector();
    }
}

// Then in the switch:
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
    process_inner_op<0>(...);
    break;
}
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?