23 changes: 23 additions & 0 deletions tests/ut/ops/test_rotary_embedding.py
@@ -450,6 +450,7 @@ def test_forward_oot_1d_positions(self, mock_cpu_arc, mock_npu_mrope):

    @patch('torch_npu.npu_mrope')
    @patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture')
    @patch('vllm_ascend.ops.rotary_embedding.HAS_TRITON', False)
    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@@ -469,3 +470,25 @@ def test_forward_oot_2d_positions(self, mock_cpu_arc, mock_npu_mrope):
        self.assertFalse(torch.isnan(result_q).any().item())
        self.assertFalse(torch.isnan(result_k).any().item())
        self.assertEqual(result_q.shape, self.query.shape)

    @patch('vllm_ascend.ops.rotary_embedding.triton_mrope')
    @patch('vllm_ascend.ops.rotary_embedding.HAS_TRITON', True)
    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
    def test_forward_triton_2d_positions(self, mock_triton_mrope):

        mock_triton_mrope.return_value = (torch.zeros_like(self.query),
                                          torch.zeros_like(self.key))

        vllm_config = self._create_vllm_config()
        with set_ascend_forward_context(None, vllm_config):
            result_q, result_k = self.layer.forward_oot(
                self.positions_2d, self.query, self.key)

        mock_triton_mrope.assert_called_once()
        self.assertFalse(torch.isnan(result_q).any().item())
        self.assertFalse(torch.isnan(result_k).any().item())
        self.assertEqual(result_q.shape, self.query.shape)
71 changes: 71 additions & 0 deletions vllm_ascend/ops/rotary_embedding.py
@@ -24,7 +24,12 @@
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

Collaborator: move this import to L31 under if HAS_TRITON
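
A minimal sketch of the guarded import the reviewer is suggesting, assuming the existing HAS_TRITON flag from vllm.triton_utils (exact placement at L31 of the reviewed file is up to the author):

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    # Only pull in the Triton mrope kernel when Triton is actually available.
    from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
else:
    # forward_oot only takes the Triton path when HAS_TRITON is True, so a
    # None placeholder (an illustrative choice, not from the PR) is never called.
    triton_mrope = None
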
from vllm.platforms import CpuArchEnum
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    import torch_npu._inductor

Collaborator: why import torch_npu._inductor here?

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
@@ -402,12 +407,78 @@ def forward(self,

class AscendMRotaryEmbedding(MRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: list[int] | None = None,
        mrope_interleaved: bool = False,
        *,
        scaling_factor: float | None = None,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.cos = None
        self.sin = None
        extra_kwargs = {
            "scaling_factor": scaling_factor,
            "extrapolation_factor": extrapolation_factor,
            "attn_factor": attn_factor,
            "beta_fast": beta_fast,
            "beta_slow": beta_slow
        }
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype, mrope_section,
                         mrope_interleaved, **extra_kwargs)

    def forward_triton(self,
                       positions: torch.Tensor,
                       query: torch.Tensor,
                       key: torch.Tensor | None = None,
                       offsets: torch.Tensor | None = None):
        assert positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)

        if self.cos is None and self.sin is None:
            cos_sin = self.cos_sin_cache[positions]  # type: ignore
            cos, sin = cos_sin.chunk(2, dim=-1)
            self.cos = cos.contiguous()
            self.sin = sin.contiguous()
        query_shape = query.shape
        key_shape = key.shape

        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            self.cos,
            self.sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        if HAS_TRITON and positions.ndim == 2:
            # TODO: needs a CANN update
            return self.forward_triton(positions, query, key)

Collaborator: this relies on CANN 8.5; please make it backward compatible with CANN 8.3.RC2
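
A rough sketch of one way to keep the fallback explicit while older CANN toolkits are still supported; VLLM_ASCEND_MROPE_USE_TRITON is a hypothetical opt-in switch, not an existing vllm-ascend option:

import os

# Hypothetical opt-in: default off, so CANN 8.3.RC2 environments keep using the
# existing torch_npu.npu_mrope path instead of the Triton kernel.
_USE_TRITON_MROPE = os.getenv("VLLM_ASCEND_MROPE_USE_TRITON", "0") == "1"

# The guard in forward_oot
#     if HAS_TRITON and positions.ndim == 2:
# would then read
#     if HAS_TRITON and _USE_TRITON_MROPE and positions.ndim == 2: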

        # TODO: This check will be removed once the mrope precision issue is fixed
        if self.mrope_section != [
                16, 24, 24