-
Notifications
You must be signed in to change notification settings - Fork 657
[Fusion] [Graph] Add qknorm rope fusion operator #4711
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a graph fusion pass for qknorm_rope operations on Ascend hardware, which is a great step for performance optimization. The implementation includes a new configuration flag, a pattern matching pass using torch._inductor.pattern_matcher, and a custom Triton kernel for the fused operation. The code is well-structured, but I've identified several areas for improvement regarding code quality, robustness, and maintainability. My review comments focus on removing debug artifacts, improving code clarity and consistency, enhancing robustness by avoiding hardcoded values and unsafe module-level initializations, and addressing significant code duplication.
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
| dtype=self.dtype, | ||
| device=self.device) | ||
| # For GQA models. | ||
| elif not self.vllm_config.model_config.use_mla: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be more careful on this condition. As far as I know, GQA models does not always has rope_dim of 128, and this hardcode might cause some potential bugs. Perhaps we can limit it to qwen3_moe only?
| return q_output, k_output, v_output | ||
|
|
||
|
|
||
| direct_register_custom_op(op_name="qkv_rmsnorm_rope", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use import torch_npu._inductor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pattern_matcher method of inductor does not support the triton operator. It does support torch.ops.aten (aten operator), torch.ops.npu (custom operator), and torch.add (PyTorch API). Therefore, it is wrapped as a custom op.
| return driver.active.utils.get_device_properties(device) | ||
|
|
||
|
|
||
| num_vectorcore = get_npu_properties()["num_vectorcore"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this parameter has already been defined in triton/utils.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I have modified it.
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
bac8b40 to
59f15a7
Compare
|
This pr rely on #4409, because ci has no triton. |
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
5a890c9 to
98a4d21
Compare
|
|
||
| return q_rope, k_rope, v | ||
|
|
||
| def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pattern in 'if xxx else: torch.ops.vllm.qkv_rmsnorm_rope ’ need support in future releases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't perform any special checks in the pattern. You can add a new pattern match.
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
3d8050a to
95718dd
Compare
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
bb324de to
ec6b0df
Compare
Signed-off-by: wxsIcey <[email protected]>
What this PR does / why we need it?
This PR add
qkv_rmsnorm_ropeoperator and introduces a graph fusion pass forqknorm_ropeoperations. The implementation includes a new configuration flag, a pattern matching pass usingtorch._inductor.pattern_matcher, and a custom Triton kernel for the fused operation.Co-authored-by: Angazenn [email protected]
Does this PR introduce any user-facing change?
Yes, add new additional_config
How was this patch tested?
todo