diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index f5d4f66336f..a965b015d83 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -450,6 +450,7 @@ def test_forward_oot_1d_positions(self, mock_cpu_arc, mock_npu_mrope):
 
     @patch('torch_npu.npu_mrope')
     @patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture')
+    @patch('vllm_ascend.ops.rotary_embedding.HAS_TRITON', False)
     @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
     @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
     @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@@ -469,3 +470,24 @@ def test_forward_oot_2d_positions(self, mock_cpu_arc, mock_npu_mrope):
         self.assertFalse(torch.isnan(result_q).any().item())
         self.assertFalse(torch.isnan(result_k).any().item())
         self.assertEqual(result_q.shape, self.query.shape)
+
+    @patch('vllm_ascend.ops.rotary_embedding.triton_mrope')
+    @patch('vllm_ascend.ops.rotary_embedding.HAS_TRITON', True)
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_triton_2d_positions(self, mock_triton_mrope):
+
+        mock_triton_mrope.return_value = (torch.zeros_like(self.query),
+                                          torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_2d, self.query, self.key)
+
+        mock_triton_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 91a6f09fa1a..c5c9c755104 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -24,7 +24,12 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
 from vllm.platforms import CpuArchEnum
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import torch_npu._inductor  # noqa: F401  (side-effect import for the Triton path)
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
@@ -402,12 +407,78 @@ def forward(self,
 
 class AscendMRotaryEmbedding(MRotaryEmbedding):
 
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        mrope_section: list[int] | None = None,
+        mrope_interleaved: bool = False,
+        *,
+        scaling_factor: float | None = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.cos = None  # cos/sin halves are cached by forward_triton on first use
+        self.sin = None
+        extra_kwargs = {
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow
+        }
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype, mrope_section,
+                         mrope_interleaved, **extra_kwargs)
+
+    def forward_triton(self,
+                       positions: torch.Tensor,
+                       query: torch.Tensor,
+                       key: torch.Tensor | None = None,
+                       offsets: torch.Tensor | None = None):
+        assert positions.ndim == 2
+        assert key is not None
+
+        self._match_cos_sin_cache_dtype(query)
+
+        if self.cos is None and self.sin is None:
+            cos_sin = self.cos_sin_cache[positions]  # type: ignore
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            self.cos = cos.contiguous()
+            self.sin = sin.contiguous()
+        query_shape = query.shape
+        key_shape = key.shape
+
+        assert self.mrope_section
+
+        q, k = triton_mrope(
+            query,
+            key,
+            self.cos,
+            self.sin,
+            self.mrope_section,
+            self.head_size,
+            self.rotary_dim,
+            self.mrope_interleaved,
+        )
+
+        return q.reshape(query_shape), k.reshape(key_shape)
+
     def forward_oot(
         self,
         positions: torch.Tensor,
        query: torch.Tensor,
         key: torch.Tensor,
     ):
+        if HAS_TRITON and positions.ndim == 2:
+            # TODO: needs a CANN update
+            return self.forward_triton(positions, query, key)
         # TODO: This judgment will be removed once the mrope precision issue is fixed
         if self.mrope_section != [
                 16, 24, 24
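
For context, a minimal standalone sketch of the cos/sin preparation that forward_triton performs on its first call with 2-D mrope positions: index the cached table once, split it into cos and sin halves, and keep contiguous copies for the Triton kernel. The table size, base, and positions below are illustrative, not the layer's real configuration.

import torch

# Illustrative rotary table; the real layer builds self.cos_sin_cache in its base class.
rotary_dim, max_pos = 8, 32
inv_freq = 1.0 / (10000.0**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_pos, rotary_dim]

# 2-D mrope positions: one row per section (temporal / height / width).
positions = torch.tensor([[0, 1, 2, 3],
                          [0, 0, 1, 1],
                          [0, 1, 0, 1]])

# What forward_triton caches on its first call and then reuses.
cos_sin = cos_sin_cache[positions]            # [3, num_tokens, rotary_dim]
cos, sin = cos_sin.chunk(2, dim=-1)           # two [3, num_tokens, rotary_dim // 2]
cos, sin = cos.contiguous(), sin.contiguous()
print(cos.shape, sin.shape)                   # torch.Size([3, 4, 4]) twice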