24 | 24 | from vllm.model_executor.layers.rotary_embedding import ( |
25 | 25 | DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, |
26 | 26 | YaRNScalingRotaryEmbedding) |
| 27 | +from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope |
27 | 28 | from vllm.platforms import CpuArchEnum |
| 29 | +from vllm.triton_utils import HAS_TRITON |
| 30 | + |
| 31 | +if HAS_TRITON: |
| 32 | + import torch_npu._inductor |
28 | 33 |
29 | 34 | from vllm_ascend.platform import NPUPlatform |
30 | 35 | from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, |
31 | 36 | get_ascend_device_type) |
32 | 37 |
33 | | - |
34 | 38 | def _custom_rotary_embedding_enabled(query, neox_style, head_size): |
35 | 39 | return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( |
36 | 40 | ) |
@@ -402,12 +406,79 @@ def forward(self, |
402 | 406 |
403 | 407 | class AscendMRotaryEmbedding(MRotaryEmbedding): |
404 | 408 |
| 409 | + def __init__( |
| 410 | + self, |
| 411 | + head_size: int, |
| 412 | + rotary_dim: int, |
| 413 | + max_position_embeddings: int, |
| 414 | + base: float, |
| 415 | + is_neox_style: bool, |
| 416 | + dtype: torch.dtype, |
| 417 | + mrope_section: list[int] | None = None, |
| 418 | + mrope_interleaved: bool = False, |
| 419 | + *, |
| 420 | + scaling_factor: float | None = None, |
| 421 | + extrapolation_factor: float = 1, |
| 422 | + attn_factor: float = 1, |
| 423 | + beta_fast: int = 32, |
| 424 | + beta_slow: int = 1, |
| 425 | + ) -> None: |
| 426 | + self.cos = None |
| 427 | + self.sin = None |
| 428 | + extra_kwargs = { |
| 429 | + "scaling_factor": scaling_factor, |
| 430 | + "extrapolation_factor": extrapolation_factor, |
| 431 | + "attn_factor": attn_factor, |
| 432 | + "beta_fast": beta_fast, |
| 433 | + "beta_slow": beta_slow |
| 434 | + } |
| 435 | + super().__init__(head_size, rotary_dim, max_position_embeddings, base, |
| 436 | + is_neox_style, dtype, mrope_section, mrope_interleaved, **extra_kwargs) |
| 437 | + |
| 438 | + def forward_triton( |
| 439 | + self, |
| 440 | + positions: torch.Tensor, |
| 441 | + query: torch.Tensor, |
| 442 | + key: torch.Tensor | None = None, |
| 443 | + offsets: torch.Tensor | None = None |
| 444 | + ): |
| 445 | + assert positions.ndim == 2 |
| 446 | + assert key is not None |
| 447 | + |
| 448 | + self._match_cos_sin_cache_dtype(query) |
| 449 | + |
| 450 | + if self.cos is None and self.sin is None: |
| 451 | + cos_sin = self.cos_sin_cache[positions] |
| 452 | + cos, sin = cos_sin.chunk(2, dim=-1) |
| 453 | + self.cos = cos.contiguous() |
| 454 | + self.sin = sin.contiguous() |
| 455 | + query_shape = query.shape |
| 456 | + key_shape = key.shape |
| 457 | + |
| 458 | + assert self.mrope_section |
| 459 | + |
| 460 | + q, k = triton_mrope( |
| 461 | + query, |
| 462 | + key, |
| 463 | + self.cos, |
| 464 | + self.sin, |
| 465 | + self.mrope_section, |
| 466 | + self.head_size, |
| 467 | + self.rotary_dim, |
| 468 | + self.mrope_interleaved, |
| 469 | + ) |
| 470 | + |
| 471 | + return q.reshape(query_shape), k.reshape(key_shape) |
| 472 | + |
405 | 473 | def forward_oot( |
406 | 474 | self, |
407 | 475 | positions: torch.Tensor, |
408 | 476 | query: torch.Tensor, |
409 | 477 | key: torch.Tensor, |
410 | 478 | ): |
| 479 | + if HAS_TRITON and positions.ndim == 2: |
| 480 | +            # TODO: needs CANN update
| 481 | + return self.forward_triton(positions, query, key) |
411 | 482 |         # TODO: This check will be removed once the mrope precision issue is fixed
412 | 483 | if self.mrope_section != [ |
413 | 484 | 16, 24, 24 |
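For context on what the new forward_triton path computes (this note is not part of the diff): triton_mrope applies a sectioned multimodal rotary embedding, where the temporal, height and width position axes each drive their own slice of the rotary frequencies, sized by mrope_section (e.g. [16, 24, 24]). The sketch below is a minimal eager-mode reference for the non-interleaved, neox-style case under those assumptions; the function name, argument layout and shapes are illustrative and are not the signature of the Triton kernel.

import torch

def mrope_reference(query, key, cos, sin, mrope_section, head_size, rotary_dim):
    # cos/sin: [3, num_tokens, rotary_dim // 2] -- one table per position axis
    # (temporal, height, width). mrope_section, e.g. [16, 24, 24], says how many
    # of the rotary_dim // 2 frequency slots each axis contributes.
    cos_parts, sin_parts, start = [], [], 0
    for axis, width in enumerate(mrope_section):
        cos_parts.append(cos[axis, :, start:start + width])
        sin_parts.append(sin[axis, :, start:start + width])
        start += width
    cos_cat = torch.cat(cos_parts, dim=-1)  # [num_tokens, rotary_dim // 2]
    sin_cat = torch.cat(sin_parts, dim=-1)

    def rotate(x):
        # x: [num_tokens, num_heads * head_size]
        tokens = x.shape[0]
        x = x.view(tokens, -1, head_size)
        rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
        x1, x2 = rot.chunk(2, dim=-1)  # neox-style half split
        c = cos_cat.unsqueeze(1)       # broadcast over heads
        s = sin_cat.unsqueeze(1)
        rotated = torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
        return torch.cat((rotated, rest), dim=-1).reshape(tokens, -1)

    return rotate(query), rotate(key)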