141 changes: 141 additions & 0 deletions tests/e2e/nightly/ops/triton/test_rope.py
@@ -0,0 +1,141 @@
import gc

import pytest
import torch

from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]
NUM_K_HEADS = [1]
NUM_TOKENS = [1, 4, 8, 16, 1024]
SEEDS = [0]
DEVICES = ["npu:0"]
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


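# Neox-style rotation: split the last dim into two halves and rotate them.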
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


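# GPT-J-style rotation: rotate interleaved (even, odd) pairs along the last dim.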
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)


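# For each pair (x1, x2) taken from the rope slice, RoPE applies
#   x1' = x1 * cos - x2 * sin
#   x2' = x2 * cos + x1 * sin
# where neox-style pairs element i with i + rotary_dim // 2 and
# gptj-style pairs adjacent elements (2i, 2i + 1).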
def _rope_pytorch_native(
query, key, cos, sin, rope_dim,
is_neox_style) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
orig_dtype = query.dtype
query_rot = query[..., :rope_dim].to(torch.float32)
key_rot = key[..., :rope_dim].to(torch.float32)
head_size = query.shape[-1]
if rope_dim < head_size:
query_pass = query[..., rope_dim:]
key_pass = key[..., rope_dim:]

if is_neox_style:
cos = cos.repeat(1, 2).unsqueeze(-2).to(torch.float32)
sin = sin.repeat(1, 2).unsqueeze(-2).to(torch.float32)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)

rotate_fn = rotate_neox if is_neox_style else rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin

if rope_dim < head_size:
query = torch.cat((query_rot.to(orig_dtype), query_pass), dim=-1)
key = torch.cat((key_rot.to(orig_dtype), key_pass), dim=-1)
else:
query = query_rot.to(orig_dtype)
key = key_rot.to(orig_dtype)
return query, key


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel(
is_neox_style: bool,
num_tokens: int,
num_q_heads: int,
num_k_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
if rotary_dim == -1:
rotary_dim = head_size
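# cos/sin hold half the rotary dim per token; the reference implementation
# expands them to the full rotary_dim before applying the rotation.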
sin = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
cos = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
q_trt = torch.randn(num_tokens,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_trt = torch.randn(num_tokens,
num_k_heads,
head_size,
dtype=dtype,
device=device)
q_gold = torch.randn(num_tokens,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_gold = torch.randn(num_tokens,
num_k_heads,
head_size,
dtype=dtype,
device=device)
q_trt.copy_(q_gold)
k_trt.copy_(k_gold)
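# Run the Triton kernel and the PyTorch-native reference on identical inputs.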
q_trt, k_trt = rope_forward_triton(q_trt,
k_trt,
cos,
sin,
rope_dim=rotary_dim,
is_neox_style=is_neox_style)
q_gold, k_gold = _rope_pytorch_native(q_gold,
k_gold,
cos,
sin,
rope_dim=rotary_dim,
is_neox_style=is_neox_style)
# Compare the results.
torch.testing.assert_close(q_trt.view(q_gold.size()),
q_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(k_trt.view(k_gold.size()),
k_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
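# Release NPU memory between parameterized runs.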
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
61 changes: 39 additions & 22 deletions vllm_ascend/attention/sfa_v1.py
@@ -9,13 +9,15 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz)
@@ -490,35 +492,50 @@ def indexer_select(
cos = attn_metadata.cos
sin = attn_metadata.sin

cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

# q process in new stream
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [b,s,64,128]
q_pe, q_nope = torch.split(
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64]

q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]

k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
k = self.k_norm(k_proj).unsqueeze(1)
k_pe, k_nope = torch.split(
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]

k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)

k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
k = k.view(-1, 1, self.head_dim)

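# Apply RoPE to the rope slices of q and k: use the fused Triton kernel when
# Triton is available, otherwise fall back to torch_npu.npu_interleave_rope
# on the split rope/nope halves.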
if HAS_TRITON:
cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
q, k = rope_forward_triton(q,
k,
cos,
sin,
rope_dim=self.qk_rope_head_dim,
is_neox_style=True)
else:
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

q_pe, q_nope = torch.split(
q,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64]

q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]

k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]

k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)

k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]

if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),