
Conversation

Contributor

@shiyuan680 shiyuan680 commented Dec 3, 2025

What this PR does / why we need it?

This PR adds support for a Triton mrope kernel, analogous to cuda_forward; its performance is on par with the AscendC ops.
The Triton op requires CANN 8.5.0.

Does this PR introduce any user-facing change?

How was this patch tested?

Tested Qwen3-VL-235B accuracy on TextVQA:

  • native: 81.82
  • NPU Triton: 81.58
  • CUDA Triton: 81.52

Performance is on par with the AscendC ops.

@github-actions
Copy link

github-actions bot commented Dec 3, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces Triton support for Multi-head Rotary Position Embedding (mrope). My review identified a critical issue in the implementation within AscendMRotaryEmbedding. The cos and sin values for the rotary embeddings are incorrectly cached after the first forward pass, which would lead to incorrect computations for all subsequent batches with different positions. I have provided a detailed comment with a code suggestion to rectify this bug by ensuring these values are recomputed on every forward pass.

Comment on lines 438 to 471
    def forward_triton(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None
    ):
        assert positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)

        if self.cos is None and self.sin is None:
            cos_sin = self.cos_sin_cache[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)
            self.cos = cos.contiguous()
            self.sin = sin.contiguous()
        query_shape = query.shape
        key_shape = key.shape

        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            self.cos,
            self.sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)
Contributor


critical

The current implementation incorrectly caches cos and sin values in self.cos and self.sin during the first forward pass. These values depend on the positions tensor, which can vary between calls. Reusing stale cached values will result in incorrect rotary embeddings for subsequent batches. The cos and sin tensors should be computed on every call to forward_triton and should be local variables, not instance attributes. Consequently, self.cos and self.sin should also be removed from the __init__ method.

    def forward_triton(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None
    ):
        assert positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)

        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.contiguous()
        sin = sin.contiguous()
        query_shape = query.shape
        key_shape = key.shape

        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            cos,
            sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)
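
To make the failure mode concrete, here is a minimal standalone sketch (not the PR's class; the cache construction is simplified purely for illustration) showing that the gathered cos/sin values depend on positions and therefore go stale if cached across calls:

import torch

# Simplified stand-in for the layer's cos/sin cache: rows indexed by position.
rotary_dim, max_pos = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # [max_pos, rotary_dim]

pos_batch_a = torch.tensor([[0, 1, 2]])  # positions seen on the first call
pos_batch_b = torch.tensor([[5, 6, 7]])  # a later call with different positions

cos_a, sin_a = cos_sin_cache[pos_batch_a].chunk(2, dim=-1)
cos_b, sin_b = cos_sin_cache[pos_batch_b].chunk(2, dim=-1)

# The gathered values differ, so cos/sin cached from batch A would be wrong for batch B.
assert not torch.allclose(cos_a, cos_b)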

@shiyuan680 shiyuan680 changed the title from "[draft]support triton mrope" to "[ops]support triton mrope" on Dec 4, 2025
Signed-off-by: shiyuan680 <[email protected]>
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
Collaborator

move this import to L31 under if HAS_TRITON
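
A minimal sketch of the layout this comment suggests, assuming the triton_mrope import can simply join the existing HAS_TRITON guard (the exact placement and the noqa handling are assumptions):

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    # Both imports are only needed when the Triton path is actually usable.
    import torch_npu._inductor  # noqa: F401
    from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope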

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    import torch_npu._inductor
Collaborator

why import torch_npu._inductor here?

@wangxiyuan
Collaborator

please rebase to main

    ):
        if HAS_TRITON and positions.ndim == 2:
            # todo: need cann update
            return self.forward_triton(positions, query, key)
Collaborator

This relies on CANN 8.5; please keep it backward compatible with CANN 8.3RC2.
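
One possible shape for a backward-compatible dispatch, sketched here as a suggestion only (the capability helper _triton_mrope_supported() and the AscendC fallback name forward_oot are assumptions, not code from this PR):

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ):
        # Use the Triton kernel only when Triton is available, positions have the
        # mrope 2-D layout, and the runtime (e.g. CANN >= 8.5) supports it.
        if HAS_TRITON and positions.ndim == 2 and self._triton_mrope_supported():
            return self.forward_triton(positions, query, key)
        # Otherwise fall back to the existing AscendC implementation so that
        # CANN 8.3RC2 keeps working.
        return self.forward_oot(positions, query, key, offsets)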

@github-actions

github-actions bot commented Dec 8, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.
