[ops]support triton mrope #4668
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces Triton support for Multimodal Rotary Position Embedding (MRoPE). My review identified a critical issue in the AscendMRotaryEmbedding implementation: the cos and sin values for the rotary embeddings are incorrectly cached after the first forward pass, which would lead to incorrect computations for all subsequent batches with different positions. I have provided a detailed comment with a code suggestion that fixes this bug by ensuring these values are recomputed on every forward pass.
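As a minimal, self-contained illustration of the failure mode (a toy sketch only; ToyRope and its shapes are hypothetical, not the PR's actual class):

import torch

# Toy reproduction of the caching bug described above: cos/sin depend on
# `positions`, but are computed only once and then reused on every call.
class ToyRope:
    def __init__(self, max_pos: int, dim: int):
        self.cos_sin_cache = torch.randn(max_pos, dim)
        self.cos = None
        self.sin = None

    def forward(self, positions: torch.Tensor):
        if self.cos is None and self.sin is None:
            cos_sin = self.cos_sin_cache[positions]
            self.cos, self.sin = cos_sin.chunk(2, dim=-1)
        return self.cos, self.sin

rope = ToyRope(max_pos=16, dim=8)
cos_a, _ = rope.forward(torch.tensor([[0, 1, 2]]))
cos_b, _ = rope.forward(torch.tensor([[3, 4, 5]]))  # different positions...
print(torch.equal(cos_a, cos_b))  # ...prints True: stale cos/sin are reused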
vllm_ascend/ops/rotary_embedding.py
Outdated
def forward_triton(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None
):
    assert positions.ndim == 2
    assert key is not None

    self._match_cos_sin_cache_dtype(query)

    if self.cos is None and self.sin is None:
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        self.cos = cos.contiguous()
        self.sin = sin.contiguous()
    query_shape = query.shape
    key_shape = key.shape

    assert self.mrope_section

    q, k = triton_mrope(
        query,
        key,
        self.cos,
        self.sin,
        self.mrope_section,
        self.head_size,
        self.rotary_dim,
        self.mrope_interleaved,
    )

    return q.reshape(query_shape), k.reshape(key_shape)
The current implementation incorrectly caches cos and sin values in self.cos and self.sin during the first forward pass. These values depend on the positions tensor, which can vary between calls. Reusing stale cached values will result in incorrect rotary embeddings for subsequent batches. The cos and sin tensors should be computed on every call to forward_triton and should be local variables, not instance attributes. Consequently, self.cos and self.sin should also be removed from the __init__ method.
def forward_triton(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None
):
    assert positions.ndim == 2
    assert key is not None
    self._match_cos_sin_cache_dtype(query)
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.contiguous()
    sin = sin.contiguous()
    query_shape = query.shape
    key_shape = key.shape
    assert self.mrope_section
    q, k = triton_mrope(
        query,
        key,
        cos,
        sin,
        self.mrope_section,
        self.head_size,
        self.rotary_dim,
        self.mrope_interleaved,
    )
    return q.reshape(query_shape), k.reshape(key_shape)
Force-pushed from 6a3b508 to 07980a4.
Signed-off-by: shiyuan680 <[email protected]>
Force-pushed from 07980a4 to 4663f09.
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
move this import to L31 under if HAS_TRITON
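For illustration, a sketch of the guarded import this comment asks for (module paths are taken from the diff in this PR; the exact placement at L31 of the file is not reproduced here):

from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    # Import the Triton kernel only when Triton is available on this platform.
    from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope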
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    import torch_npu._inductor
why import torch_npu._inductor here?
please rebase to main
):
    if HAS_TRITON and positions.ndim == 2:
        # todo: need cann update
        return self.forward_triton(positions, query, key)
This relies on CANN 8.5; please make it backward compatible with CANN 8.3.RC2.
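One possible shape for such a fallback, sketched below; the VLLM_ASCEND_ENABLE_TRITON_MROPE flag and the _forward_ascendc helper are hypothetical placeholders, not names defined by this PR or by vLLM Ascend:

import os

from vllm.triton_utils import HAS_TRITON

# Hypothetical opt-in flag so deployments on CANN 8.3.RC2 keep the current path.
_ENABLE_TRITON_MROPE = os.getenv("VLLM_ASCEND_ENABLE_TRITON_MROPE", "0") == "1"

def forward(self, positions, query, key=None, offsets=None):
    if HAS_TRITON and _ENABLE_TRITON_MROPE and positions.ndim == 2:
        # New Triton mrope kernel; per the PR description it needs CANN 8.5.0.
        return self.forward_triton(positions, query, key)
    # Fall back to the existing path, which works on CANN 8.3.RC2.
    # `_forward_ascendc` stands in for whatever that path is called.
    return self._forward_ascendc(positions, query, key, offsets)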
This pull request has conflicts, please resolve those before we can evaluate the pull request.
What this PR does / why we need it?
This PR adds support for a Triton mrope kernel, analogous to cuda_forward; its performance is on par with the AscendC ops.
The Triton op requires CANN 8.5.0.
Does this PR introduce any user-facing change?
How was this patch tested?
Accuracy tested with qwen3-vl-235b on TextVQA:
native: 81.82
npu triton: 81.58
cuda triton: 81.52
Performance is on par with the AscendC ops.