Commit 6a3b508

support triton mrope
Signed-off-by: shiyuan680 <[email protected]>
1 parent 7271f0d commit 6a3b508

File tree

2 files changed: +93 -1 lines changed

tests/ut/ops/test_rotary_embedding.py

Lines changed: 21 additions & 0 deletions
@@ -469,3 +469,24 @@ def test_forward_oot_2d_positions(self, mock_cpu_arc, mock_npu_mrope):
         self.assertFalse(torch.isnan(result_q).any().item())
         self.assertFalse(torch.isnan(result_k).any().item())
         self.assertEqual(result_q.shape, self.query.shape)
+
+    @patch('vllm.model_executor.layers.rotary_embedding.mrope.triton_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.triton_utils.HAS_TRITON', return_value=True)
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_triton_2d_positions(self, mock_triton_mrope):
+
+        mock_triton_mrope.return_value = (torch.zeros_like(self.query),
+                                          torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_2d, self.query, self.key)
+
+        mock_triton_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)

vllm_ascend/ops/rotary_embedding.py

Lines changed: 72 additions & 1 deletion
@@ -24,13 +24,17 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
 from vllm.platforms import CpuArchEnum
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import torch_npu._inductor
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
                                get_ascend_device_type)
 
-
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
     return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
     )
@@ -402,12 +406,79 @@ def forward(self,
 
 class AscendMRotaryEmbedding(MRotaryEmbedding):
 
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        mrope_section: list[int] | None = None,
+        mrope_interleaved: bool = False,
+        *,
+        scaling_factor: float | None = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.cos = None
+        self.sin = None
+        extra_kwargs = {
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow
+        }
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype, mrope_section, mrope_interleaved, **extra_kwargs)
+
+    def forward_triton(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+        offsets: torch.Tensor | None = None
+    ):
+        assert positions.ndim == 2
+        assert key is not None
+
+        self._match_cos_sin_cache_dtype(query)
+
+        if self.cos is None and self.sin is None:
+            cos_sin = self.cos_sin_cache[positions]
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            self.cos = cos.contiguous()
+            self.sin = sin.contiguous()
+        query_shape = query.shape
+        key_shape = key.shape
+
+        assert self.mrope_section
+
+        q, k = triton_mrope(
+            query,
+            key,
+            self.cos,
+            self.sin,
+            self.mrope_section,
+            self.head_size,
+            self.rotary_dim,
+            self.mrope_interleaved,
+        )
+
+        return q.reshape(query_shape), k.reshape(key_shape)
+
     def forward_oot(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
     ):
+        if HAS_TRITON and positions.ndim == 2:
+            # todo: need cann update
+            return self.forward_triton(positions, query, key)
         # TODO: This judgment will be removed once the mrope precision issue is fixed
         if self.mrope_section != [
             16, 24, 24
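For orientation, here is a small usage sketch (not part of the commit) of the path this change adds: when Triton is available and `positions` is 2-D (one row per MRoPE section, as vLLM builds for multimodal inputs), `forward_oot` now routes through `forward_triton` and the `triton_mrope` kernel instead of the NPU fallback. The constructor arguments mirror the `__init__` added above; the sizes, shapes, and device handling (`device="npu"` via `torch_npu`) are illustrative assumptions for an Ascend environment with vllm-ascend, torch_npu, and Triton installed.

import torch
import torch_npu  # noqa: F401  (registers the "npu" device)

from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding

# Illustrative sizes only: 8 heads of size 128; mrope_section covers
# rotary_dim // 2 (16 + 24 + 24 == 64 == 128 // 2).
num_tokens, num_heads, head_size = 4, 8, 128

rope = AscendMRotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,
    max_position_embeddings=4096,
    base=10000.0,
    is_neox_style=True,
    dtype=torch.bfloat16,
    mrope_section=[16, 24, 24],
    mrope_interleaved=False,
).to("npu")

# 2-D positions (one row per t/h/w section) are what select the Triton path.
positions = torch.zeros(3, num_tokens, dtype=torch.long, device="npu")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.bfloat16, device="npu")
key = torch.randn(num_tokens, num_heads * head_size,
                  dtype=torch.bfloat16, device="npu")

# With HAS_TRITON true, forward_oot dispatches to forward_triton -> triton_mrope.
q, k = rope.forward_oot(positions, query, key)

Note that the unit test added in this commit instead mocks `triton_mrope` and wraps the call in `set_ascend_forward_context(...)`, so the dispatch logic can be exercised without Triton kernels or NPU hardware.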
