1 change: 1 addition & 0 deletions setup.py
@@ -160,6 +160,7 @@ def gen_build_info():
         "ascend310p3vir02": "_310P",
         "ascend310p3vir04": "_310P",
         "ascend310p3vir08": "_310P",
+        "ascend910_9579": "_910_95",
     }
 
     assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
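For context, a minimal sketch of the lookup this table feeds. The entries and the assert come from the diff; wrapping them in a standalone helper is an illustrative assumption (the diff only shows them inside gen_build_info()):

# Sketch: map a reported SoC version string to the device-family suffix
# used by the build. "ascend910_9579" -> "_910_95" is the new entry.
def resolve_device_suffix(soc_version: str) -> str:
    soc_to_device = {
        "ascend310p3vir02": "_310P",
        "ascend310p3vir04": "_310P",
        "ascend310p3vir08": "_310P",
        "ascend910_9579": "_910_95",
    }
    assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
    return soc_to_device[soc_version]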
61 changes: 44 additions & 17 deletions vllm_ascend/attention/attention_v1.py
@@ -41,7 +41,9 @@
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               prefill_context_parallel_enable,
+                               weak_ref_tensors)
 
 # isort: off
 if prefill_context_parallel_enable():
@@ -1413,12 +1415,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1432,18 +1442,35 @@
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
-                            num_decode_tokens:attn_metadata.
-                            num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded].contiguous(),
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_mapping=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded],
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
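Both hunks above apply the same pattern: a device-gated choice of KV-cache write kernel. Below is a condensed sketch of that dispatch. The op names and imports come from the diff itself; the keyword signature shown for npu_scatter_pa_kv_cache mirrors the decode-path call and should be treated as an assumption (the prefill-path call instead passes a contiguous value, slot_mapping=, and out=).

import torch_npu

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


def write_kv_cache(key, value, key_cache, value_cache, slot_mapping):
    # Sketch of the dispatch introduced above: Ascend 910_95 scatters KV
    # entries with the paged-attention scatter kernel, while every other
    # device keeps the existing reshape-and-cache path.
    if get_ascend_device_type() == AscendDeviceType._910_95:
        torch_npu.npu_scatter_pa_kv_cache(
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            slot_indices=slot_mapping)
    else:
        torch_npu._npu_reshape_and_cache(
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            slot_indices=slot_mapping)

The device check is re-evaluated on every forward step; hoisting it to __init__ would avoid the per-call branch, though the cost of an enum comparison is likely negligible here.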
3 changes: 2 additions & 1 deletion vllm_ascend/ops/rotary_embedding.py
@@ -414,7 +414,8 @@ def forward_oot(
         # TODO: This judgment will be removed once the mrope precision issue is fixed
         if self.mrope_section != [
                 16, 24, 24
-        ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86:
+        ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86 or \
+                get_ascend_device_type() == AscendDeviceType._910_95:
             return super().forward_oot(positions, query, key)
 
         import torch_npu
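The guard now short-circuits to the base implementation in three cases. Restated as a standalone predicate for clarity, a minimal sketch (enum members are taken from the diff; the helper itself and the CpuArchEnum import path are illustrative assumptions):

from vllm.platforms import CpuArchEnum  # import path assumed

from vllm_ascend.utils import AscendDeviceType


def should_use_default_rope(mrope_section, cpu_arch, device_type) -> bool:
    # Fall back whenever the mrope split is not the tuned [16, 24, 24]
    # layout, the host is x86, or the device is Ascend 910_95 (per the
    # TODO above, pending a fix for the mrope precision issue).
    return (mrope_section != [16, 24, 24]
            or cpu_arch == CpuArchEnum.X86
            or device_type == AscendDeviceType._910_95)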
6 changes: 4 additions & 2 deletions vllm_ascend/worker/worker_v1.py
@@ -52,7 +52,8 @@
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
+from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
+                               get_ascend_device_type, is_enable_nz,
                                prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
@@ -344,7 +345,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if get_ascend_device_type() != AscendDeviceType._910_95:
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
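The gate here is the same device check inverted: the ATB matmul warm-up exists to pre-initialize the ReshapeAndCache path, which the 910_95 branch in attention_v1.py above no longer takes, so the warm-up is skipped on that device. A minimal sketch, with needs_atb_warmup as a hypothetical helper:

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


def needs_atb_warmup() -> bool:
    # Hypothetical helper: on Ascend 910_95 the diff skips the ATB warm-up,
    # since cache writes there go through npu_scatter_pa_kv_cache rather
    # than the ReshapeAndCache path the warm-up targets.
    return get_ascend_device_type() != AscendDeviceType._910_95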