Commit c94c66e

register apply_rotary_emb custom op
Signed-off-by: shen-shanshan <[email protected]>
1 parent 38bd952 commit c94c66e

File tree: 3 files changed, +72 -27 lines

vllm_ascend/ops/rotary_embedding.py

Lines changed: 40 additions & 1 deletion
@@ -16,14 +16,16 @@
 #
 
 import math
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple
 
+import einops
 import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
 from vllm.platforms import CpuArchEnum
 
 from vllm_ascend.platform import NPUPlatform
@@ -435,3 +437,40 @@ def forward_oot(
             rotary_mode='half')
 
         return query, key
+
+
+class AscendApplyRotaryEmb(ApplyRotaryEmb):
+
+    def __init__(
+        self,
+        is_neox_style: bool = False,
+        is_unsqueeze: bool = False,
+        default: Callable[..., torch.Tensor] | None = None,
+    ) -> None:
+        super().__init__(is_neox_style, is_unsqueeze, default)
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ) -> torch.Tensor:
+        # x: [2 * b, s, head, head_dim]
+        qk = einops.rearrange(
+            x, "(two b) s head head_dim -> b s two head head_dim", two=2)
+        # q/k: [b, s, head, head_dim]
+        q, k = qk[:, :, 0], qk[:, :, 1]
+        head_dim = q.shape[-1]
+
+        cos = torch.cat((cos, cos), dim=-1)
+        sin = torch.cat((sin, sin), dim=-1)
+        cos = cos.reshape(1, -1, 1, head_dim)
+        sin = sin.reshape(1, -1, 1, head_dim)
+        # cos/sin: [1, s, 1, head_dim]
+
+        q = torch_npu.npu_rotary_mul(q, cos, sin)
+        k = torch_npu.npu_rotary_mul(k, cos, sin)
+
+        # output: [2 * b, s, head, head_dim]
+        output = torch.cat([q, k], dim=0)
+        return output
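
The shape bookkeeping in AscendApplyRotaryEmb.forward_oot can be sanity-checked off-device with a small pure-PyTorch sketch. It assumes torch_npu.npu_rotary_mul applies the usual rotate-half form x * cos + rotate_half(x) * sin (the exact kernel semantics are owned by torch_npu), so treat it only as a shape/semantics illustration, not a drop-in replacement for the NPU op.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) along the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_reference(x: torch.Tensor, cos: torch.Tensor,
                               sin: torch.Tensor) -> torch.Tensor:
    # x: [2 * b, s, head, head_dim], q stacked on top of k along dim 0.
    # cos/sin: [s, head_dim // 2], as the half-dim vision rotary tables are assumed to be.
    two_b, s, _, head_dim = x.shape
    q, k = x.view(2, two_b // 2, s, -1, head_dim).unbind(dim=0)

    # Same cat + reshape as forward_oot: [s, head_dim // 2] -> [1, s, 1, head_dim].
    cos = torch.cat((cos, cos), dim=-1).reshape(1, s, 1, head_dim)
    sin = torch.cat((sin, sin), dim=-1).reshape(1, s, 1, head_dim)

    # Rotate-half rotary application, assumed equivalent to npu_rotary_mul.
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return torch.cat([q, k], dim=0)


if __name__ == "__main__":
    b, s, head, head_dim = 2, 16, 4, 64
    x = torch.randn(2 * b, s, head, head_dim)
    cos, sin = torch.randn(s, head_dim // 2), torch.randn(s, head_dim // 2)
    out = apply_rotary_emb_reference(x, cos, sin)
    assert out.shape == (2 * b, s, head, head_dim)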

vllm_ascend/patch/worker/patch_qwen2_5_vl.py

Lines changed: 28 additions & 24 deletions
@@ -32,8 +32,6 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.rotary_embedding.common import (
-    apply_rotary_emb_torch, dispatch_rotary_emb_function)
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
     Qwen2_5_VisionPatchMerger, Qwen2_5_VisionTransformer,
@@ -69,36 +67,50 @@ def forward(
         x, _ = self.qkv(x)
         seq_len, batch_size, _ = x.shape
 
-        # Split q k v.
         qkv = einops.rearrange(
             x,
             "s b (three head head_dim) -> b s three head head_dim",
             three=3,
             head=self.num_attention_heads_per_partition,
         )
-        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
-        origin_shape = q.shape[-1]
 
         # Convert cumulative tensor to intervals and move it to cpu.
         cu_seqlens = torch.diff(cu_seqlens).to("cpu")
 
-        cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
-        sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
-        cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head)
-        sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head)
-        q = torch_npu.npu_rotary_mul(q, cos, sin)
-        k = torch_npu.npu_rotary_mul(k, cos, sin)
+        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
+            qk, v = qkv[:, :, :2], qkv[:, :, 2]
 
-        q, k, v = [
-            einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
-            for x in (q, k, v)
-        ]
+            qk_reshaped = einops.rearrange(
+                qk, "b s two head head_dim -> (two b) s head head_dim", two=2)
+            qk_rotated = self.apply_rotary_emb(
+                qk_reshaped,
+                rotary_pos_emb_cos,
+                rotary_pos_emb_sin,
+            )
+            qk_rotated = qk_rotated.view(
+                2,
+                batch_size,
+                seq_len,
+                self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head,
+            )
+            q, k = qk_rotated.unbind(dim=0)
+        else:
+            q, k, v = qkv.unbind(dim=2)
 
+        # TODO(shen-shanshan): Move codes below to MMEncoderAttention CustomOp
+        # ----------------------------------------------------------------------
         enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
                       and self.hidden_size_per_attention_head > MIN_PAD_SIZE
                       and self.hidden_size_per_attention_head < MAX_PAD_SIZE)
 
+        q, k, v = [
+            einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
+            for x in (q, k, v)
+        ]
+
         if enable_pad:
+            origin_shape = q.shape[-1]
             pad_len = MAX_PAD_SIZE - origin_shape
             # q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
             q = F.pad(q, (0, pad_len), mode="constant", value=0)
@@ -125,6 +137,7 @@ def forward(
         context_layer = einops.rearrange(context_layer,
                                          "(b s) h d -> s b (h d)",
                                          b=batch_size).contiguous()
+        # ----------------------------------------------------------------------
 
         output, _ = self.proj(context_layer)
         return output
@@ -650,14 +663,6 @@ def _process_video_input(
         return video_embeds.split(sizes)
 
 
-def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor,
-                                 sin: torch.Tensor) -> torch.Tensor:
-    rotary_emb_function = dispatch_rotary_emb_function(
-        default=partial(apply_rotary_emb_torch, is_neox_style=True))
-    output = rotary_emb_function(t, cos, sin).type_as(t)
-    return output
-
-
 # NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm.
 Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
 Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
@@ -676,4 +681,3 @@ def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor,
 Qwen2_5_VisionTransformer.rotary_pos_emb_thw = AscendQwen2_5_VisionTransformer.rotary_pos_emb_thw
 Qwen2_5_VisionTransformer.get_rope_by_thw = AscendQwen2_5_VisionTransformer.get_rope_by_thw
 Qwen2_5_VisionTransformer.forward = AscendQwen2_5_VisionTransformer.forward
-apply_rotary_pos_emb_vision = _apply_rotary_pos_emb_vision
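
In the patched vision attention, q and k are packed into a single (two b) batch before the rotary op and unpacked with view + unbind afterwards. A shape-only sketch with made-up sizes (no NPU required) confirms that the einops "(two b)" factorization keeps q in the first half of the packed batch and k in the second, so the round-trip recovers the original tensors:

import einops
import torch

# Hypothetical sizes, chosen only to exercise the reshapes used in the patch.
batch_size, seq_len, heads, head_dim = 2, 8, 4, 64

qkv = torch.randn(batch_size, seq_len, 3, heads, head_dim)
qk = qkv[:, :, :2]  # q and k stacked along the "three" axis

# Pack q and k along the batch dim; "two" is the outer factor, so rows
# [0:batch_size] hold q and rows [batch_size:2*batch_size] hold k.
qk_reshaped = einops.rearrange(
    qk, "b s two head head_dim -> (two b) s head head_dim", two=2)

# ... the rotary op (AscendApplyRotaryEmb) would be applied to qk_reshaped here ...

# Unpack with a plain view + unbind, as the patched forward does.
qk_rotated = qk_reshaped.view(2, batch_size, seq_len, heads, head_dim)
q, k = qk_rotated.unbind(dim=0)

assert torch.equal(q, qkv[:, :, 0])
assert torch.equal(k, qkv[:, :, 1])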

vllm_ascend/utils.py

Lines changed: 4 additions & 2 deletions
@@ -679,8 +679,9 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         AscendRowParallelLinear)
     from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
-        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
+        AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding,
+        AscendMRotaryEmbedding, AscendRotaryEmbedding,
+        AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -706,6 +707,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "FusedMoE": AscendFusedMoE,
         "SharedFusedMoE": AscendSharedFusedMoE,
         "MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention,
+        "ApplyRotaryEmb": AscendApplyRotaryEmb,
     }
 
     for name, op_cls in REGISTERED_ASCEND_OPS.items():
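
The new REGISTERED_ASCEND_OPS entry maps the op name "ApplyRotaryEmb" to its Ascend override. The snippet below is only a conceptual sketch of that name-to-override dispatch pattern, with a hypothetical registry and classes rather than vLLM's actual CustomOp machinery:

from typing import Dict, Type


class ApplyRotaryEmb:
    # Hypothetical stand-in for the default (native PyTorch) op.
    def forward(self) -> str:
        return "native apply_rotary_emb"


class AscendApplyRotaryEmb(ApplyRotaryEmb):
    # Hypothetical stand-in for the NPU override registered above.
    def forward(self) -> str:
        return "npu apply_rotary_emb (npu_rotary_mul)"


# Name -> override table, analogous to REGISTERED_ASCEND_OPS.
REGISTERED_OPS: Dict[str, Type[ApplyRotaryEmb]] = {
    "ApplyRotaryEmb": AscendApplyRotaryEmb,
}


def build_op(name: str, default_cls: Type[ApplyRotaryEmb]) -> ApplyRotaryEmb:
    # Prefer the platform override when one has been registered for this name.
    return REGISTERED_OPS.get(name, default_cls)()


print(build_op("ApplyRotaryEmb", ApplyRotaryEmb).forward())
# -> npu apply_rotary_emb (npu_rotary_mul)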
