diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py
deleted file mode 100644
index 3829d6a7c19..00000000000
--- a/examples/offline_dualbatch_overlap_npu.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import os
-import time
-
-from vllm import LLM, SamplingParams
-
-os.environ["VLLM_USE_MODELSCOPE"] = "True"
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-# enable dual-batch overlap for vllm ascend
-os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
-
-# Sample prompts.
-prompts = ["The president of the United States is"] * 41
-# Create a sampling params object.
-sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
-
-
-def main():
-    # Create an LLM.
-    llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
-              enforce_eager=True,
-              tensor_parallel_size=2,
-              max_model_len=4096,
-              trust_remote_code=True,
-              enable_expert_parallel=True,
-              additional_config={
-                  "torchair_graph_config": {
-                      "enabled": False
-                  },
-                  "ascend_scheduler_config": {
-                      "enabled": True
-                  },
-              })
-
-    # Generate texts from the prompts. The output is a list of RequestOutput
-    # objects that contain the prompt, generated text, and other information.
-    outputs = llm.generate(prompts, sampling_params)
-
-    # Print the outputs.
-    print("-" * 50)
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
-        print("-" * 50)
-
-    # Add a buffer to wait for profiler in the background process
-    # (in case MP is on) to finish writing profiling output.
-    time.sleep(10)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 59353e9e5bb..e7b17a36e69 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -623,11 +623,8 @@ def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
         self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
 
     @patch('vllm_ascend.attention.mla_v1.get_forward_context')
-    @patch("torch.npu.stream")
-    @patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
    @patch("torch_npu.npu_fused_infer_attention_score")
     def test_forward_decode(self, mock_npu_fused_infer_attention_score,
-                            mock_get_multistream_comm_context, mock_npu_stream,
                             mock_get_forward_context):
         B = 2
         N = self.impl.num_kv_heads
@@ -651,8 +648,6 @@ def test_forward_decode(self, mock_npu_fused_infer_attention_score,
         mock_npu_fused_infer_attention_score.return_value = [
             torch.randn(B, N, self.impl.kv_lora_rank), None
         ]
-        mock_get_multistream_comm_context.return_value = None
-
         mock_get_forward_context.return_value = MagicMock(capturing=False)
         result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
                                            attn_metadata)
@@ -660,18 +655,3 @@ def test_forward_decode(self, mock_npu_fused_infer_attention_score,
         self.assertEqual(result.shape[0], B)
         self.assertEqual(result.shape[1], N)
         self.assertEqual(result.shape[2], HD)
-
-        self.impl.enable_kv_nz = False
-        attn_metadata.attn_state = None
-        mock_return_value = MagicMock()
-        mock_get_multistream_comm_context.return_value = mock_return_value
-        mock_return_value.before_comm_event = MagicMock()
-        mock_return_value.comm_stream = MagicMock()
-        mock_npu_stream.return_value = MagicMock()
-
-        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
-                                           attn_metadata)
-
-        self.assertEqual(result.shape[0], B)
-        self.assertEqual(result.shape[1], N)
-        self.assertEqual(result.shape[2], HD)
diff --git a/tests/ut/multistream/test_base.py b/tests/ut/multistream/test_base.py
deleted file mode 100644
index 4bdd29b8a5c..00000000000
--- a/tests/ut/multistream/test_base.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from tests.ut.base import TestBase
-from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
-                                          MSEventKey)
-
-
-class Testbase(TestBase):
-
-    def test_ms_event_key(self):
-        self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
-        self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
-        self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
-        self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
-        self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
-        self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
-        self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
-        self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
-        self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
-
-    def test_ms_attention_metadata_split_config_default(self):
-        config = MSAttentionMetadataSplitConfig()
-        self.assertEqual(config.num_micro_batches, 2)
-        self.assertEqual(config.min_total_tokens_to_split, 256)
-        self.assertEqual(config.min_prefill_tokens_to_split, 64)
-
-    def test_ms_attention_metadata_split_config_custom(self):
-        config = MSAttentionMetadataSplitConfig(
-            num_micro_batches=4,
-            min_total_tokens_to_split=512,
-            min_prefill_tokens_to_split=128)
-        self.assertEqual(config.num_micro_batches, 4)
-        self.assertEqual(config.min_total_tokens_to_split, 512)
-        self.assertEqual(config.min_prefill_tokens_to_split, 128)
diff --git a/tests/ut/multistream/test_decorator.py b/tests/ut/multistream/test_decorator.py
deleted file mode 100644
index bd3da9402e4..00000000000
--- a/tests/ut/multistream/test_decorator.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import pytest
-from pytest_mock import MockFixture
-
-from tests.ut.base import PytestBase
-from vllm_ascend.multistream.decorator import set_multistream_support
-
-
-class Context:
-
-    def __init__(self, attn_metadata=None):
-        self.attn_metadata = attn_metadata
-
-
-class TestDecorator(PytestBase):
-
-    @pytest.mark.parametrize(
-        'layer_context, microbatch_context, expected_metadata', [
-            ((-1, None, None), -1, {
-                "original": True
-            }),
-            ((-1, None, None), 0, {
-                "original": True
-            }),
-            ((0, None, None), -1, {
-                "original": True
-            }),
-            ((0, None, [{
-                "new": True
-            }]), 0, {
-                "new": True
-            }),
-        ])
-    def test_decorator(self, mocker: MockFixture, layer_context,
-                       microbatch_context, expected_metadata):
-
-        def context_func():
-            return Context(attn_metadata={"original": True})
-
-        mocker.patch(
-            'vllm_ascend.multistream.decorator.get_multistream_layer_context',
-            return_value=layer_context)
-        mocker.patch(
-            'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
-            return_value=microbatch_context)
-
-        context = set_multistream_support()(context_func)()
-        assert context.attn_metadata == expected_metadata
diff --git a/tests/ut/multistream/test_layers.py b/tests/ut/multistream/test_layers.py
deleted file mode 100644
index cf34c6a09bc..00000000000
--- a/tests/ut/multistream/test_layers.py
+++ /dev/null
@@ -1,198 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# - -from unittest.mock import MagicMock, patch - -import pytest -import torch - -from tests.ut.base import PytestBase -from vllm_ascend.multistream.base import MSEventKey -from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, - MultiStreamPreTransformerLayer) -from vllm_ascend.multistream.metadata import MultiStreamMetadata - - -# === fixture: mock tensor input === -@pytest.fixture -def input_tensors(): - return [torch.randn(2, 128), torch.randn(2, 128)] - - -# === mock get_forward_context === -class DummyContext: - - def __init__(self, attn_metadata): - self.attn_metadata = attn_metadata - - -class TestMultiStreamPreTransformerLayer(PytestBase): - - # === test when multistream_metadata is None === - @patch("vllm_ascend.multistream.layers.get_forward_context") - @patch("vllm_ascend.multistream.layers.set_multistream_layer_context") - def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx, - input_tensors): - mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta") - layer = MultiStreamPreTransformerLayer(multistream_metadata=None) - attn_out, input_out = layer.forward(input_tensors) - - assert attn_out == "dummy_meta" - assert input_out == input_tensors - mock_set_ctx.assert_called_once_with(-1, None, None) - - # === test when attn_metadata is None === - @patch("vllm_ascend.multistream.layers.get_forward_context") - @patch("vllm_ascend.multistream.layers.set_multistream_layer_context") - def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx, - input_tensors): - mock_get_ctx.return_value = DummyContext(attn_metadata=None) - dummy_metadata = MagicMock(spec=MultiStreamMetadata) - layer = MultiStreamPreTransformerLayer( - multistream_metadata=dummy_metadata) - - attn_out, input_out = layer.forward(input_tensors) - - assert attn_out is None - assert input_out == input_tensors - mock_set_ctx.assert_called_once_with(-1, None, None) - - # === test when do_ms=False (no split needed) === - @patch("vllm_ascend.multistream.layers.get_forward_context") - @patch("vllm_ascend.multistream.layers.set_multistream_layer_context") - def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors): - dummy_attn = "original_attn" - mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn) - - dummy_metadata = MagicMock(spec=MultiStreamMetadata) - dummy_metadata.split_micro_batch.return_value = (False, "same_attn", - input_tensors, None) - - layer = MultiStreamPreTransformerLayer( - multistream_metadata=dummy_metadata) - - attn_out, input_out = layer.forward(input_tensors) - - assert attn_out == "same_attn" - assert input_out == input_tensors - mock_set_ctx.assert_called_once_with(-1, None, None) - - # === test when do_ms=True (split occurred) === - @patch("vllm_ascend.multistream.layers.get_forward_context") - @patch("vllm_ascend.multistream.layers.set_multistream_layer_context") - def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors): - dummy_attn = "original_attn" - mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn) - - split_inputs = [[t[:1], t[1:]] for t in input_tensors] - - dummy_metadata = MagicMock(spec=MultiStreamMetadata) - dummy_metadata.start_layer = 2 - dummy_metadata.split_micro_batch.return_value = (True, - ["attn1", "attn2"], - split_inputs, None) - - layer = MultiStreamPreTransformerLayer( - 
multistream_metadata=dummy_metadata) - - attn_out, input_out = layer.forward(input_tensors) - - assert attn_out == ["attn1", "attn2"] - assert input_out == split_inputs - mock_set_ctx.assert_called_once_with(2, dummy_metadata, - ["attn1", "attn2"]) - - -class TestMultiStreamPostTransformerLayer(PytestBase): - - def test_post_forward_metadata_none(self, input_tensors): - layer = MultiStreamPostTransformerLayer(multistream_metadata=None) - output = layer.forward(input_tensors) - assert output == input_tensors - - dummy_metadata = MagicMock(spec=MultiStreamMetadata) - dummy_metadata.ms_config = None - layer = MultiStreamPostTransformerLayer( - multistream_metadata=dummy_metadata) - output = layer.forward(input_tensors) - assert output == input_tensors - - @patch("vllm_ascend.multistream.layers.get_multistream_layer_context") - @patch("vllm_ascend.multistream.layers.reset_multistream_layer_context") - def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx, - input_tensors): - A_instance_of_MultiStreamMetadata = MultiStreamMetadata( - calculate_stream=MagicMock(), - communicate_stream=MagicMock(), - start_layer=0, - end_layer=1, - event_keys=[], - multistream_config=None, - ) - dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata) - dummy_metadata.ms_config.num_micro_batches = 4 - dummy_metadata.end_layer = 10 - - mock_get_ctx.return_value = ( - 5, # layer_index - dummy_metadata, # ms_metadata - "dummy_attn_metadata" # ms_attn_metadata - ) - - dummy_metadata.merge_micro_batches.return_value = "merged_result" - - layer = MultiStreamPostTransformerLayer( - multistream_metadata=dummy_metadata) - output = layer.forward(input_tensors) - - # check wait_event - dummy_metadata.try_wait_event.assert_called_once_with( - 9, # end_layer - 1 - 3, # num_micro_batches - 1 - MSEventKey.FFN_AR_FINISH) - mock_reset_ctx.assert_called_once() - assert output == "merged_result" - - @patch("vllm_ascend.multistream.layers.get_multistream_layer_context") - @patch("vllm_ascend.multistream.layers.reset_multistream_layer_context") - def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx, - mock_get_ctx, input_tensors): - A_instance_of_MultiStreamMetadata = MultiStreamMetadata( - calculate_stream=MagicMock(), - communicate_stream=MagicMock(), - start_layer=0, - end_layer=1, - event_keys=[], - multistream_config=None, - ) - dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata) - dummy_metadata.ms_config.num_micro_batches = 4 - dummy_metadata.end_layer = 10 - - mock_get_ctx.return_value = ( - 3, # layer_index - dummy_metadata, - "dummy_attn_metadata") - - dummy_metadata.merge_micro_batches.return_value = "merged_result" - - layer = MultiStreamPostTransformerLayer( - multistream_metadata=dummy_metadata) - output = layer.forward(input_tensors, wait_layer_index=7) - - dummy_metadata.try_wait_event.assert_called_once_with( - 7, 3, MSEventKey.FFN_AR_FINISH) - mock_reset_ctx.assert_called_once() - assert output == "merged_result" diff --git a/tests/ut/multistream/test_metadata.py b/tests/ut/multistream/test_metadata.py deleted file mode 100644 index 79fd703d14e..00000000000 --- a/tests/ut/multistream/test_metadata.py +++ /dev/null @@ -1,246 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.multistream.base import MSEventKey -from vllm_ascend.multistream.metadata import (MultiStreamConfig, - MultiStreamMetadata, - MultiStreamStepMetadata, - split_micro_batches_tensors) - - -class TestMetaData(TestBase): - - 
def setUp(self): - self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)] - self.test_tensors = torch.randn(100, 1024) - self.test_tensors_dict = { - 'query': torch.randn(100, 1024), - 'key': torch.randn(100, 1024), - 'value': torch.randn(100, 1024) - } - self.split_index = 50 - - mock_stream = MagicMock(spec=torch.npu.Stream) - event_keys = [MagicMock(spec=MSEventKey)] - multistream_config = MagicMock(spec=MultiStreamConfig) - - self.metadata = MultiStreamMetadata( - calculate_stream=mock_stream, - communicate_stream=mock_stream, - start_layer=1, - end_layer=3, - event_keys=event_keys, - multistream_config=multistream_config) - - def test_split_micro_batches_tensors(self): - test_tensors_list_res = split_micro_batches_tensors( - self.test_tensors_list, self.split_index) - test_tensors_res = split_micro_batches_tensors(self.test_tensors, - self.split_index) - keys = ['query', 'key', 'value'] - test_tensors_dict_res = split_micro_batches_tensors( - self.test_tensors_dict, self.split_index, keys) - for i in range(3): - self.assertEqual(len(test_tensors_list_res[i][0]), - self.split_index) - - self.assertEqual( - len(test_tensors_list_res[i][0]) + - len(test_tensors_list_res[i][1]), 100) - - self.assertEqual(len(test_tensors_res[0]), self.split_index) - self.assertEqual( - len(test_tensors_res[0]) + len(test_tensors_res[1]), 100) - - for key in keys: - self.assertEqual(len(test_tensors_dict_res[0][key]), - self.split_index) - self.assertEqual( - len(test_tensors_dict_res[0][key]) + - len(test_tensors_dict_res[1][key]), 100) - - def test_default_init_multistream_step_metadata(self): - metadata = MultiStreamStepMetadata() - self.assertIsNone(metadata.comm_stream) - self.assertIsNone(metadata.before_comm_event) - self.assertIsNone(metadata.after_comm_event) - - def test_custom_init_multistream_step_metadata(self): - mockStream = MagicMock(spec=torch.npu.Stream) - mockEvent1 = MagicMock(spec=torch.npu.Event) - mockEvent2 = MagicMock(spec=torch.npu.Event) - - metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2) - self.assertEqual(metadata.comm_stream, mockStream) - self.assertEqual(metadata.before_comm_event, mockEvent1) - self.assertEqual(metadata.after_comm_event, mockEvent2) - - def test_default_init_multistream_config(self): - config = MultiStreamConfig() - self.assertEqual(config.min_total_tokens_to_split, 256) - self.assertEqual(config.min_prefill_tokens_to_split, 64) - self.assertEqual(config.num_micro_batches, 2) - self.assertEqual(config.imbalance_ratio, 0.1) - - def test_custom_init_multistream_config(self): - config = MultiStreamConfig(512, 128, 1, 0.2) - self.assertEqual(config.min_total_tokens_to_split, 512) - self.assertEqual(config.min_prefill_tokens_to_split, 128) - self.assertEqual(config.num_micro_batches, 1) - self.assertEqual(config.imbalance_ratio, 0.2) - - def test_init_multistream_metadata(self): - mock_stream = MagicMock(spec=torch.npu.Stream) - - event_keys = [MagicMock()] - multistream_config = MagicMock(spec=MultiStreamConfig) - - metadata = MultiStreamMetadata(calculate_stream=mock_stream, - communicate_stream=mock_stream, - start_layer=1, - end_layer=3, - event_keys=event_keys, - multistream_config=multistream_config) - - self.assertEqual(metadata.calculate_stream, mock_stream) - self.assertEqual(metadata.communicate_stream, mock_stream) - self.assertEqual(metadata.start_layer, 1) - self.assertEqual(metadata.end_layer, 3) - self.assertEqual(metadata.ms_config, multistream_config) - self.assertTrue(metadata.causal_lm) - - def 
test_build_events(self): - mock_stream = MagicMock(spec=torch.npu.Stream) - mock_event = MagicMock(spec=torch.npu.Event) - with patch('torch.npu.Event', return_value=mock_event): - event_keys = [MagicMock(spec=MSEventKey)] - multistream_config = MultiStreamConfig( - num_micro_batches=2, - min_total_tokens_to_split=256, - min_prefill_tokens_to_split=64) - - metadata = MultiStreamMetadata( - calculate_stream=mock_stream, - communicate_stream=mock_stream, - start_layer=1, - end_layer=3, - event_keys=event_keys, - multistream_config=multistream_config) - - expected_events = { - 0: { - 0: { - event_keys[0]: mock_event - }, - 1: { - event_keys[0]: mock_event - } - }, - 1: { - 0: { - event_keys[0]: mock_event - }, - 1: { - event_keys[0]: mock_event - } - }, - 2: { - 0: { - event_keys[0]: mock_event - }, - 1: { - event_keys[0]: mock_event - } - } - } - self.assertEqual(metadata.ms_events, expected_events) - - def test_build_ms_split_config(self): - mock_stream = MagicMock(spec=torch.npu.Stream) - event_keys = [MagicMock(spec=MSEventKey)] - multistream_config = MagicMock(spec=MultiStreamConfig) - multistream_config.num_micro_batches = 2 - multistream_config.min_total_tokens_to_split = 256 - multistream_config.min_prefill_tokens_to_split = 64 - - metadata = MultiStreamMetadata(calculate_stream=mock_stream, - communicate_stream=mock_stream, - start_layer=1, - end_layer=3, - event_keys=event_keys, - multistream_config=multistream_config) - - self.assertIsNotNone(metadata.ms_split_config) - self.assertEqual(metadata.ms_split_config.num_micro_batches, - multistream_config.num_micro_batches) - self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split, - multistream_config.min_total_tokens_to_split) - self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split, - multistream_config.min_prefill_tokens_to_split) - - def test_try_wait_event(self): - mock_stream = MagicMock(spec=torch.npu.Stream) - mock_event = MagicMock(spec=torch.npu.Event) - event_keys = [MagicMock(spec=MSEventKey)] - multistream_config = MagicMock(spec=MultiStreamConfig) - with patch('torch.npu.Event', return_value=mock_event): - metadata = MultiStreamMetadata( - calculate_stream=mock_stream, - communicate_stream=mock_stream, - start_layer=1, - end_layer=3, - event_keys=event_keys, - multistream_config=multistream_config) - - metadata.try_wait_event(layer_index=1, - micro_batch_index=0, - event_key=event_keys[0]) - mock_event.wait.assert_called_once() - - def test_try_record_event(self): - mock_stream = MagicMock(spec=torch.npu.Stream) - mock_event = MagicMock(spec=torch.npu.Event) - event_keys = [MagicMock(spec=MSEventKey)] - multistream_config = MagicMock(spec=MultiStreamConfig) - with patch('torch.npu.Event', return_value=mock_event): - metadata = MultiStreamMetadata( - calculate_stream=mock_stream, - communicate_stream=mock_stream, - start_layer=1, - end_layer=3, - event_keys=event_keys, - multistream_config=multistream_config) - - metadata.try_record_event(layer_index=1, - micro_batch_index=0, - event_key=event_keys[0]) - mock_event.record.assert_called_once() - - def test_merge_batches_none_input(self): - input_tensors = None - result = self.metadata.merge_micro_batches(input_tensors) - self.assertIsNone(result) - - def test_merge_batches_single_tensor_input(self): - input_tensors = [torch.tensor([1, 2, 3])] - result = self.metadata.merge_micro_batches(input_tensors) - self.assertEqual(len(result), 1) - self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3]))) - - def 
test_merge_batches_list_of_tensors_input(self): - input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])] - result = self.metadata.merge_micro_batches(input_tensors) - self.assertEqual(len(result), 2) - self.assertEqual(result, input_tensors) - - def test_merge_batches_nested_list_input(self): - input_tensors = [[torch.tensor([1, 2]), - torch.tensor([3, 4])], - [torch.tensor([5, 6]), - torch.tensor([7, 8])]] - result = self.metadata.merge_micro_batches(input_tensors) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4]))) - self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8]))) diff --git a/tests/ut/multistream/test_ms_split.py b/tests/ut/multistream/test_ms_split.py deleted file mode 100644 index e76321a6e55..00000000000 --- a/tests/ut/multistream/test_ms_split.py +++ /dev/null @@ -1,147 +0,0 @@ -from unittest.mock import MagicMock - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig -from vllm_ascend.multistream.ms_split import (compute_split_seq_index, - model_input_split_v1_mla_attn, - split_attn_int_type, - split_attn_tensor_type) - - -class TestMsSplit(TestBase): - - def test_decode_only(self): - result = compute_split_seq_index( - query_lens=None, - attn_state=AscendAttentionState.DecodeOnly, - num_tokens=10) - self.assertEqual(result, [5, 5]) - - def test_perfect_balance(self): - query_lens = [2, 3, 5] - result = compute_split_seq_index( - query_lens=query_lens, - attn_state=AscendAttentionState.PrefillNoCache, - num_tokens=10) - self.assertEqual(result, [5, 2]) - - def test_imbalance(self): - query_lens = [1, 2, 3, 4] - result = compute_split_seq_index( - query_lens=query_lens, - attn_state=AscendAttentionState.PrefillNoCache, - num_tokens=10) - self.assertEqual(result, [0, 0]) - - def test_query_lens_none(self): - with self.assertRaises(AssertionError): - compute_split_seq_index( - query_lens=None, - attn_state=AscendAttentionState.PrefillNoCache, - num_tokens=10) - - def test_empty_query_lens(self): - query_lens: list[int] = [] - result = compute_split_seq_index( - query_lens=query_lens, - attn_state=AscendAttentionState.PrefillNoCache, - num_tokens=10) - self.assertEqual(result, [0, 0]) - - def test_single_query_len(self): - query_lens = [10] - result = compute_split_seq_index( - query_lens=query_lens, - attn_state=AscendAttentionState.PrefillNoCache, - num_tokens=10) - self.assertEqual(result, [0, 0]) - - def test_split_attn_tensor_type_middle(self): - input_tensor = torch.tensor([1, 2, 3, 4, 5]) - index = 3 - expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])] - result = split_attn_tensor_type(input_tensor, index) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], expected_result[0])) - self.assertTrue(torch.equal(result[1], expected_result[1])) - - def test_split_attn_tensor_type_start(self): - input_tensor = torch.tensor([1, 2, 3, 4, 5]) - index = 0 - expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])] - result = split_attn_tensor_type(input_tensor, index) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], expected_result[0])) - self.assertTrue(torch.equal(result[1], expected_result[1])) - - def test_split_attn_tensor_type_end(self): - input_tensor = torch.tensor([1, 2, 3, 4, 5]) - index = 5 - expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])] - result = 
split_attn_tensor_type(input_tensor, index) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], expected_result[0])) - self.assertTrue(torch.equal(result[1], expected_result[1])) - - def test_split_attn_tensor_type_empty_tensor(self): - input_tensor = torch.tensor([]) - index = 0 - expected_result = [torch.tensor([]), torch.tensor([])] - result = split_attn_tensor_type(input_tensor, index) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], expected_result[0])) - self.assertTrue(torch.equal(result[1], expected_result[1])) - - def test_split_attn_int_type_index_greater_than_var(self): - var = 5 - index = 10 - expected_result = [5, 0] - result = split_attn_int_type(var, index) - self.assertEqual(result, expected_result) - - def test_split_attn_int_type_index_equal_to_var(self): - var = 5 - index = 5 - expected_result = [5, 0] - result = split_attn_int_type(var, index) - self.assertEqual(result, expected_result) - - def test_split_attn_int_type_index_less_than_var(self): - var = 10 - index = 5 - expected_result = [5, 5] - result = split_attn_int_type(var, index) - self.assertEqual(result, expected_result) - - def test_split_attn_int_type_index_zero(self): - var = 10 - index = 0 - expected_result = [0, 10] - result = split_attn_int_type(var, index) - self.assertEqual(result, expected_result) - - def test_split_attn_int_type_var_zero(self): - var = 0 - index = 5 - expected_result = [0, 0] - result = split_attn_int_type(var, index) - self.assertEqual(result, expected_result) - - def test_split_attn_int_type_both_zero(self): - var = 0 - index = 0 - expected_result = [0, 0] - result = split_attn_int_type(var, index) - self.assertEqual(result, expected_result) - - def test_split_v1_mla_attn_input_none(self): - attn_metadata = None - ascendMLAPrefillMetadata = MagicMock() - ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1) - result = model_input_split_v1_mla_attn(attn_metadata, - ascendMLAPrefillMetadata, - ms_split_config) - self.assertEqual(result, [None]) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 93cac288a6a..e874db250bc 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -213,9 +213,6 @@ class AscendMetadata: # (num_tokens,) slot_mapping: torch.Tensor = None - # *************************** Other Properties *************************** # - enable_dbo_across_dp: bool = False - prefill: Optional[AscendMetadataForPrefill] = None decode_meta: Optional[AscendMetadataForDecode] = None @@ -374,7 +371,6 @@ def build( slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, num_prefills=num_prefills, num_decodes=num_decodes, prefill=prefill_metadata, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 01fe9a369d2..c4503f599d5 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -36,9 +36,6 @@ trans_rope_weight, transdata, wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import get_graph_params -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig -from vllm_ascend.multistream.context import get_multistream_comm_context -from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.utils import 
(ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, @@ -184,7 +181,6 @@ class AscendMLAMetadata: decode: Optional[AscendMLADecodeMetadata] = None prefill: Optional[AscendMLAPrefillMetadata] = None - enable_dbo_across_dp: bool = False def __post_init__(self): pass @@ -195,17 +191,6 @@ def __post_init__(self): # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") - def split_metadata_for_multistream( - self, - ms_split_config: MSAttentionMetadataSplitConfig, - ) -> list["AscendMLAMetadata"]: - """Split metadata for multi-stream with AscendMLAMetadata""" - return model_input_split_v1_mla_attn( - ms_split_config=ms_split_config, - attn_metadata=self, - _metadata_cls=AscendMLAMetadata, - ) - M = TypeVar("M", bound=AscendMLAMetadata) @@ -538,7 +523,6 @@ def build( query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) def build_for_graph_capture( @@ -1158,14 +1142,8 @@ def _forward_decode( else: attn_output, _ = torch_npu.npu_fused_infer_attention_score( q_nope, k_nope, k_nope, **common_kwargs) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - return self._v_up_proj(attn_output) - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - return self._v_up_proj(attn_output) + + return self._v_up_proj(attn_output) def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata): bsz = attn_metadata.num_decode_tokens @@ -1423,13 +1401,8 @@ def forward( decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe, decode_preprocess_res.k_nope, decode_preprocess_res.k_pe, kv_cache[0].shape[1], attn_metadata) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is not None: - with torch.npu.stream(current_ms_metadata.comm_stream): - o_proj_input[:num_decode_tokens] = output_decode - current_ms_metadata.after_comm_event.record() - else: - o_proj_input[:num_decode_tokens] = output_decode + + o_proj_input[:num_decode_tokens] = output_decode if prefill_preprocess_res is not None: # FIX: aicore move should be also placed on the comm stream in dbo, @@ -1445,36 +1418,19 @@ def forward( prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, prefill_preprocess_res.value, kv_cache, attn_metadata) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is not None: - with torch.npu.stream(current_ms_metadata.comm_stream): - o_proj_input[num_decode_tokens:] = output_prefill - current_ms_metadata.after_comm_event.record() - else: - o_proj_input[ - num_decode_tokens:num_actual_tokens] = output_prefill + + o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill # O proj - current_ms_metadata = get_multistream_comm_context() MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 - if current_ms_metadata is None: - maybe_npu_prefetch(inputs=self.o_proj.weight, - dependency=o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=self.enable_prefetch) + maybe_npu_prefetch(inputs=self.o_proj.weight, + dependency=o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) + + output[...] = self.o_proj(o_proj_input, + is_prefill=prefill_preprocess_res + is not None)[0] - output[...] 
= self.o_proj(o_proj_input, - is_prefill=prefill_preprocess_res - is not None)[0] - else: - with torch.npu.stream(current_ms_metadata.comm_stream): - maybe_npu_prefetch(inputs=self.o_proj.weight, - dependency=o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=self.enable_prefetch) - output[...] = self.o_proj(o_proj_input, - is_prefill=prefill_preprocess_res - is not None)[0] - current_ms_metadata.after_comm_event.record() del o_proj_input has_prefill = attn_metadata.num_prefills > 0 @@ -1719,18 +1675,9 @@ def _forward_decode_pcp_dcp( attn_out_g, attn_lse_g, attn_out_l, attn_lse_l, seq_mask_pcp[:, i]) attn_output = attn_out_g - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - return self._v_up_proj(attn_output) - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - return self._v_up_proj(attn_output) - - -# TODO use update op to replace this + return self._v_up_proj(attn_output) + # TODO use update op to replace this def _update_out_and_lse( self, out: torch.Tensor, diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index e7461dc6fe1..b6514028ced 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -17,11 +17,8 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.mla_v1 import AscendMLAMetadata from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, split_decodes_and_prefills) -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig -from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.worker.npu_input_batch import InputBatch if TYPE_CHECKING: @@ -138,7 +135,6 @@ class AscendSFAMetadata: decode: Optional[AscendSFADecodeMetadata] = None prefill: Optional[AscendSFAPrefillMetadata] = None - enable_dbo_across_dp: bool = False def __post_init__(self): pass @@ -149,17 +145,6 @@ def __post_init__(self): # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") - def split_metadata_for_multistream( - self, - ms_split_config: MSAttentionMetadataSplitConfig, - ) -> list["AscendSFAMetadata"]: - """Split metadata for multi-stream with AscendSFAMetadata""" - return model_input_split_v1_mla_attn( - ms_split_config=ms_split_config, - attn_metadata=self, - _metadata_cls=AscendMLAMetadata, - ) - M = TypeVar("M", bound=AscendSFAMetadata) @@ -434,7 +419,6 @@ def build( query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 27a371598d2..ede83f74a54 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -91,8 +91,6 @@ class AscendCommonAttentionMetadata: attn_state: Any = None - enable_dbo_across_dp: bool = False - is_only_prefill: bool = False graph_pad_size: int = -1 diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 4cd430c889f..8f9e1d98996 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -82,9 +82,6 @@ "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0')) ), - # Whether to enable DBO feature for deepseek model. 
-    "VLLM_ASCEND_ENABLE_DBO":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
     # Whether to enable the model execute time observe profile. Disable it when
     # running vllm ascend in production environment.
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
diff --git a/vllm_ascend/multistream/__init__.py b/vllm_ascend/multistream/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py
deleted file mode 100644
index fba58b460ea..00000000000
--- a/vllm_ascend/multistream/base.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from dataclasses import dataclass
-from enum import Enum
-
-
-class MSEventKey(Enum):
-    ATTN_COM_FINISH = 0
-    ATTN_AR_FINISH = 1
-    FFN_COM_FINISH = 2
-    FFN_AR_FINISH = 3
-    # events for MOE dispatch and combine
-    MOE_BEFORE_COMM = 4
-    MOE_AFTER_COMM = 5
-    # events for shared expert
-    MOE_SE_COMM_FINISH = 6
-    MOE_SE_COMP_FINISH = 7
-    MOE_GATE_FINISH = 8
-
-
-@dataclass
-class MSAttentionMetadataSplitConfig:
-    """
-    micro batch split config for split attention metadata
-    """
-    # micro batch num
-    num_micro_batches: int = 2
-    # split micro batches only when total tokens >= min_total_tokens_to_split
-    min_total_tokens_to_split: int = 256
-    # split micro batches only when prefill tokens >= min_prefill_tokens_to_split
-    min_prefill_tokens_to_split: int = 64
diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py
deleted file mode 100644
index a1684f2f556..00000000000
--- a/vllm_ascend/multistream/context.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from contextlib import contextmanager
-from typing import Any
-
-_ms_comm_context: Any = None
-_cur_micro_batch_num: int = -1
-_ms_layer_index_context: int = -1
-_ms_metadata_context: Any = None
-_ms_attn_metadata_context: Any = None
-
-
-def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
-                                  attn_metadata: Any):
-    """
-    set multistream layer context before transformer layers
-    """
-    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
-    _ms_layer_index_context = start_layer
-    _ms_metadata_context = ms_metadata
-    _ms_attn_metadata_context = attn_metadata
-
-
-def reset_multistream_layer_context():
-    """
-    reset multistream layer context
-    """
-    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
-    _ms_layer_index_context = -1
-    _ms_metadata_context = None
-    _ms_attn_metadata_context = None
-
-
-def get_multistream_layer_context():
-    """
-    get multistream layer context
-    """
-    return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
-
-
-def advance_step_multistream_layer_context():
-    """
-    advance multistream layer index context
-    """
-    global _ms_layer_index_context
-    _ms_layer_index_context += 1
-
-
-def get_multistream_comm_context() -> Any:
-    """Get the current comm forward context."""
-    return _ms_comm_context
-
-
-def get_multistream_microbatch_context() -> int:
-    return _cur_micro_batch_num
-
-
-@contextmanager
-def set_multistream_context(context: Any, micro_batch_num: int):
-    """A context manager that stores the current comm forward context,
-    can be attention metadata, etc."""
-    global _ms_comm_context, _cur_micro_batch_num
-    _ms_comm_context = context
-    _cur_micro_batch_num = micro_batch_num
-    try:
-        yield
-    finally:
-        _ms_comm_context = None
-        _cur_micro_batch_num = -1
diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py
deleted file mode 100644
index 5b573df26cd..00000000000
--- a/vllm_ascend/multistream/decorator.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from .context import (get_multistream_layer_context,
-                      get_multistream_microbatch_context)
-
-
-# vllm v1 use get_forward_context to get the attn_metadata,
-# we can use this decorator to update the attn metadata
-def set_multistream_support():
-
-    def decorator(func):
-
-        def wrapper():
-            context = func()
-            layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
-            )
-            micro_batch_num = get_multistream_microbatch_context()
-            if layer_index != -1 and micro_batch_num != -1:
-                context.attn_metadata = attn_metadata[micro_batch_num]
-            return context
-
-        return wrapper
-
-    return decorator
diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py
deleted file mode 100644
index c5273bce73b..00000000000
--- a/vllm_ascend/multistream/layers.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import List, Optional, Tuple, Union
-
-import torch
-from vllm.forward_context import get_forward_context
-
-from .base import MSEventKey
-from .context import (get_multistream_layer_context,
-                      reset_multistream_layer_context,
-                      set_multistream_layer_context)
-from .metadata import MultiStreamMetadata
-
-
-class MultiStreamPreTransformerLayer(torch.nn.Module):
-
-    def __init__(self, multistream_metadata: MultiStreamMetadata):
-        super().__init__()
-        self.multistream_metadata = multistream_metadata
-
-    def forward(
-        self,
-        intput_tensors: List[torch.Tensor],
-    ):
-        attn_metadata = get_forward_context().attn_metadata
-        if self.multistream_metadata is None or attn_metadata is None:
-            set_multistream_layer_context(-1, None, None)
-            return attn_metadata, intput_tensors
-        # TODO add attn_metadata management
-        do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
-            attn_metadata, intput_tensors)
-        if do_ms:
-            set_multistream_layer_context(
-                self.multistream_metadata.start_layer,
-                self.multistream_metadata, attn_metadata)
-        else:
-            set_multistream_layer_context(-1, None, None)
-        return attn_metadata, intput_tensors
-
-
-class MultiStreamPostTransformerLayer(torch.nn.Module):
-
-    def __init__(self, multistream_metadata: MultiStreamMetadata):
-        super().__init__()
-        self.multistream_metadata = multistream_metadata
-
-    def forward(self,
-                input_tensors: Union[List[Tuple[torch.Tensor]],
-                                     List[torch.Tensor],
-                                     List[List[torch.Tensor]]],
-                wait_layer_index: Optional[int] = None):
-        if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
-            return input_tensors
-        layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
-        )
-        if layer_index >= 0:
-            true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
-            self.multistream_metadata.try_wait_event(
-                true_wait_layer,
-                self.multistream_metadata.ms_config.num_micro_batches - 1,
-                MSEventKey.FFN_AR_FINISH)
-        reset_multistream_layer_context()
-        return self.multistream_metadata.merge_micro_batches(input_tensors)
diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py
deleted file mode 100644
index b521d3f85f0..00000000000
--- a/vllm_ascend/multistream/metadata.py
+++ /dev/null
@@ -1,182 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from vllm.sequence import IntermediateTensors
-
-from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
-
-from .base import MSAttentionMetadataSplitConfig, MSEventKey
-
-
-def
split_micro_batches_tensors(input_tensors, - split_index: int, - keys: Optional[List[str]] = None): - if isinstance(input_tensors, list): - micro_batches = [] - for tensor in input_tensors: - if tensor is None: - micro_batches.append([None, None]) - else: - micro_batches.append( - [tensor[:split_index], tensor[split_index:]]) - return micro_batches - elif isinstance(input_tensors, torch.Tensor): - return [input_tensors[:split_index], input_tensors[split_index:]] - elif input_tensors is None: - return [None, None] - elif isinstance(input_tensors, Dict): - assert keys is not None - micro_batches_pre = {} - for key in keys: - micro_batches_pre[key] = input_tensors[key][:split_index] - micro_batches_post = {} - for key in keys: - micro_batches_post[key] = input_tensors[key][split_index:] - return [micro_batches_pre, micro_batches_post] - else: - raise NotImplementedError - - -@dataclass -class MultiStreamStepMetadata: - comm_stream: torch.npu.Stream = None - before_comm_event: torch.npu.Event = None - after_comm_event: torch.npu.Event = None - - -@dataclass -class MultiStreamConfig: - """Controls the behavior of multi-stream models.""" - min_total_tokens_to_split: int = 256 - min_prefill_tokens_to_split: int = 64 - num_micro_batches: int = 2 - imbalance_ratio: float = 0.1 - - -class MultiStreamMetadata: - # direct stream - calculate_stream = None - # delay stream - communicate_stream = None - # events - ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {} - # multi-stream-flag - enable_multi_stream: bool = False - - def __init__( - self, - calculate_stream: torch.npu.Stream, - communicate_stream: torch.npu.Stream, - start_layer: int, - end_layer: int, - event_keys: List[MSEventKey], - multistream_config: Optional[MultiStreamConfig], - causal_lm: bool = True, - ): - self.calculate_stream = calculate_stream - self.communicate_stream = communicate_stream - self.start_layer = start_layer - self.end_layer = end_layer - self.ms_config = multistream_config - self.causal_lm = causal_lm - self._build_events(event_keys) - self._build_ms_split_config() - - def _build_events(self, event_keys): - if self.ms_config is not None: - for i in range(self.start_layer - 1, self.end_layer): - self.ms_events[i] = {} - for j in range(self.ms_config.num_micro_batches): - self.ms_events[i][j] = {} - for key in event_keys: - self.ms_events[i][j][key] = torch.npu.Event() - - def _build_ms_split_config(self): - if self.ms_config is not None: - self.ms_split_config = MSAttentionMetadataSplitConfig( - num_micro_batches=self.ms_config.num_micro_batches, - min_total_tokens_to_split=self.ms_config. - min_total_tokens_to_split, - min_prefill_tokens_to_split=self.ms_config. 
- min_prefill_tokens_to_split, - ) - - def try_wait_event(self, layer_index: int, micro_batch_index: int, - event_key: MSEventKey): - self.ms_events[layer_index][micro_batch_index][event_key].wait() - - def try_record_event(self, layer_index: int, micro_batch_index: int, - event_key: MSEventKey): - self.ms_events[layer_index][micro_batch_index][event_key].record() - - def split_micro_batch( - self, - attn_metadata: "AscendMLAMetadata", - intput_tensors: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - intermediate_tensors_keys: Optional[List[str]] = None, - ) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[ - List[torch.Tensor], List[List[torch.Tensor]]], Union[ - IntermediateTensors, List[IntermediateTensors]]]: - attn_metadata_list = attn_metadata.split_metadata_for_multistream( - self.ms_split_config) - if len(attn_metadata_list) == 1: - return False, attn_metadata_list[ - 0], intput_tensors, intermediate_tensors - split_index = attn_metadata_list[0].slot_mapping.shape[0] - input_tensors = split_micro_batches_tensors(intput_tensors, - split_index) - if intermediate_tensors is not None: - inter_tensors_list = split_micro_batches_tensors( - intermediate_tensors.tensors, split_index, - intermediate_tensors_keys) - intermediate_tensors = [ - IntermediateTensors(inter_tensors) - for inter_tensors in inter_tensors_list - ] - return True, attn_metadata_list, input_tensors, intermediate_tensors - - def merge_micro_batches( - self, input_tensors: Union[List[torch.Tensor], - List[List[torch.Tensor]]] - ) -> List[torch.Tensor]: - if input_tensors is None or isinstance(input_tensors[0], torch.Tensor): - return input_tensors - batch: List[Optional[torch.Tensor]] = [] - for tensors in input_tensors: - if tensors is None or tensors[0] is None: - batch.append(None) - else: - batch.append(torch.cat(tensors, dim=0)) - return batch - - -def make_multistream_metadata_ds( - start_layer: int, - end_layer: int, - causal_lm: bool = True, - multistream_config: Optional[MultiStreamConfig] = None, -): - if multistream_config is None: - return None - event_keylist = [ - MSEventKey.ATTN_COM_FINISH, - MSEventKey.ATTN_AR_FINISH, - MSEventKey.FFN_COM_FINISH, - MSEventKey.FFN_AR_FINISH, - MSEventKey.MOE_BEFORE_COMM, - MSEventKey.MOE_AFTER_COMM, - MSEventKey.MOE_SE_COMM_FINISH, - MSEventKey.MOE_SE_COMP_FINISH, - MSEventKey.MOE_GATE_FINISH, - ] - return MultiStreamMetadata( - calculate_stream=torch.npu.current_stream(), - communicate_stream=torch.npu.Stream(), - start_layer=start_layer, - end_layer=end_layer, - multistream_config=multistream_config, - event_keys=event_keylist, - causal_lm=causal_lm, - ) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py deleted file mode 100644 index b7b356bed93..00000000000 --- a/vllm_ascend/multistream/ms_split.py +++ /dev/null @@ -1,247 +0,0 @@ -from copy import deepcopy -from typing import Any, List, Optional - -import numpy as np -import torch - -from vllm_ascend.attention.attention_v1 import AscendAttentionState - -from .base import MSAttentionMetadataSplitConfig - - -def compute_split_seq_index( - query_lens: Optional[list[int]], - attn_state: AscendAttentionState, - num_tokens: int, - imbalance_ratio: float = 0.1, -) -> list[int]: - if attn_state != AscendAttentionState.DecodeOnly: - assert query_lens is not None - total_tokens = sum(query_lens) - # the first index in last split - tokens, split_index = 0, 0 - for value in query_lens: - tokens += value - split_index += 1 - if tokens >= 
total_tokens // 2: - # check the current split index - if abs(tokens - - total_tokens // 2) < total_tokens * imbalance_ratio: - return [tokens, split_index] - # check the previous split index - elif abs(tokens - total_tokens // 2 - - value) < total_tokens * imbalance_ratio: - return [tokens - value, split_index - 1] - # fail to split if it is imbalanced - # TODO: split tokens in seq - else: - return [0, 0] - else: - tokens = num_tokens // 2 - return [tokens, tokens] - return [0, 0] - - -def split_attn_tensor_type( - input_tensor: torch.Tensor, - index: int, -) -> List[torch.Tensor]: - return [input_tensor[:index], input_tensor[index:]] - - -def split_attn_int_type( - var: int, - index: int, -) -> List[torch.Tensor]: - return [min(var, index), max(var - index, 0)] - - -def model_input_split_v1_mla_attn( - attn_metadata, - _metadata_cls, - ms_split_config: MSAttentionMetadataSplitConfig, -) -> List[Any]: - assert 0 < ms_split_config.num_micro_batches < 3 - if attn_metadata is None: - return [attn_metadata] - [token_index, - seq_index] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return [attn_metadata] - - query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ), - dtype=int) - np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:]) - if attn_metadata.num_prefills > 0: - prefill_query_start_loc = np.zeros( - shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int) - np.cumsum(attn_metadata.prefill.query_lens, - out=prefill_query_start_loc[1:]) - - # split attn metadata - [slot_mapping_pre, - slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, - token_index) - [num_decodes_pre, - num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes, - seq_index) - [num_decode_tokens_pre, num_decode_tokens_post - ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index) - [num_prefills_pre, num_prefills_post - ] = split_attn_int_type(attn_metadata.num_prefills, - max(0, seq_index - attn_metadata.num_decodes)) - seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens - [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) - - query_start_loc_pre = query_start_loc_post = None - if attn_metadata.query_start_loc is not None: - query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] - query_start_loc_post = deepcopy( - attn_metadata.query_start_loc[seq_index:] - ) - attn_metadata.query_start_loc[seq_index] - [block_table_pre, - block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, - seq_index) - assert attn_metadata.attn_mask is not None - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: - # the attn_mla kernel in torch npu only accept 128*128 attn mask - attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = attn_metadata.attn_state - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - # should be none in decode only state - attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly - else: - # chunked prefill - if num_prefills_pre > 0: - attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( 
- seq_lens_pre)].contiguous() - attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_post = attn_metadata.attn_mask[ - token_index:, :max(seq_lens_post)].contiguous() - else: - attn_state_pre = AscendAttentionState.DecodeOnly - attn_mask_pre = None - attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_post = attn_metadata.attn_mask[ - token_index:, :max(seq_lens_post)].contiguous() - from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, - AscendMLAPrefillMetadata) - if num_prefills_pre > 0: - # split metadata.prefill - [input_positions_pre, input_positions_post] = split_attn_tensor_type( - attn_metadata.prefill.input_positions, - token_index - attn_metadata.num_decode_tokens) - [block_tables_pre, block_tables_post - ] = split_attn_tensor_type(attn_metadata.prefill.block_table, - seq_index - attn_metadata.num_decodes) - [prefill_query_lens_pre, prefill_query_lens_post - ] = split_attn_tensor_type(attn_metadata.prefill.query_lens, - seq_index - attn_metadata.num_decodes) - prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[: - seq_index - + - 1 - - attn_metadata - . - num_decodes] - prefill_query_start_loc_post = deepcopy( - attn_metadata.prefill.query_start_loc[seq_index - - attn_metadata.num_decodes:] - ) - attn_metadata.prefill.query_start_loc[seq_index - - attn_metadata.num_decodes] - context_len_pre = seq_lens_pre[attn_metadata.num_decodes:] - context_len_post = seq_lens_post - prefill_max_query_len_pre = max(prefill_query_lens_pre) - prefill_max_query_len_post = max(prefill_query_lens_post) - prefill_pre = AscendMLAPrefillMetadata( - attn_mask=attn_mask_pre, - query_lens=prefill_query_lens_pre, - seq_lens=seq_lens_pre, - query_start_loc=prefill_query_start_loc_pre, - input_positions=input_positions_pre, - context_lens=context_len_pre, - block_table=block_tables_pre, - max_query_len=prefill_max_query_len_pre, - max_seq_lens=context_len_pre.max().item(), - ) - prefill_post = AscendMLAPrefillMetadata( - attn_mask=attn_mask_post, - query_lens=prefill_query_lens_post, - seq_lens=seq_lens_post, - query_start_loc=prefill_query_start_loc_post, - input_positions=input_positions_post, - context_lens=context_len_post, - block_table=block_tables_post, - max_query_len=prefill_max_query_len_post, - max_seq_lens=context_len_post.max().item(), - ) - decode_pre = attn_metadata.decode - decode_post = None - else: - # prefill is None, split metadata.decode - [input_positions_pre, input_positions_post - ] = split_attn_tensor_type(attn_metadata.decode.input_positions, - token_index) - [block_tables_pre, block_tables_post - ] = split_attn_tensor_type(attn_metadata.decode.block_table, - seq_index) - [decode_seq_lens_pre, - decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) - decode_pre = AscendMLADecodeMetadata( - input_positions=input_positions_pre, - block_table=block_tables_pre, - seq_lens=decode_seq_lens_pre, - max_seq_lens=max(decode_seq_lens_pre), - seq_lens_list=decode_seq_lens_pre.tolist(), - ) - decode_post = AscendMLADecodeMetadata( - input_positions=input_positions_post, - block_table=block_tables_post, - seq_lens=decode_seq_lens_post, - max_seq_lens=max(decode_seq_lens_post), - seq_lens_list=decode_seq_lens_post.tolist(), - ) - prefill_pre = None - prefill_post = attn_metadata.prefill - # construct metadata - from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata - attention_metadata_pre = _metadata_cls( - num_actual_tokens=token_index, - num_input_tokens=token_index, - head_dim=attn_metadata.head_dim, - 
slot_mapping=slot_mapping_pre, - seq_lens=seq_lens_pre, - query_start_loc=query_start_loc_pre, - block_tables=block_table_pre, - num_decodes=num_decodes_pre, - num_prefills=num_prefills_pre, - num_decode_tokens=num_decode_tokens_pre, - attn_state=attn_state_pre, - attn_mask=attn_mask_pre, - prefill=prefill_pre, - decode=decode_pre, - enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, - ) - attention_metadata_post = _metadata_cls( - num_actual_tokens=attn_metadata.num_actual_tokens - token_index, - num_input_tokens=attn_metadata.num_input_tokens - token_index, - head_dim=attn_metadata.head_dim, - slot_mapping=slot_mapping_post, - seq_lens=seq_lens_post, - query_start_loc=query_start_loc_post, - block_tables=block_table_post, - num_decodes=num_decodes_post, - num_prefills=num_prefills_post, - num_decode_tokens=num_decode_tokens_post, - attn_mask=attn_mask_post, - attn_state=attn_state_post, - prefill=prefill_post, - decode=decode_post, - enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, - ) - return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index ef3925a8238..3bf4192be5c 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -116,10 +116,11 @@ def dummy_run(self, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor=None) -> None: if not self.torchair_graph_enabled: - # TODO: adapt enable_dbo later - (num_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._sync_metadata_across_dp(num_tokens, - with_prefill, False) + ( + num_tokens, + num_tokens_across_dp, + with_prefill, + ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill) moe_comm_type = self.runner._select_moe_comm_method( num_tokens, with_prefill) @@ -423,10 +424,9 @@ def _propose( if not self.torchair_graph_enabled: # torch mode need to update num_tokens_across_dp - # TODO: adapt enable_dbo later - (num_input_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._sync_metadata_across_dp( - num_input_tokens, self.runner.with_prefill, False) + (num_input_tokens, num_tokens_across_dp, + with_prefill) = self.runner._sync_metadata_across_dp( + num_input_tokens, self.runner.with_prefill) else: # torchair mode can reuse self.runner.num_tokens_across_dp num_tokens_across_dp = self.runner.num_tokens_across_dp diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 3d3177a0dca..8ea1636ac1c 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -264,8 +264,7 @@ def build( max_query_len=common_attn_metadata.max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, - attn_state=attn_state, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) + attn_state=attn_state) return attn_metadata diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 005c81f9fcf..ce539b7d682 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -20,9 +20,6 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, split_decodes_and_prefills) -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig -from vllm_ascend.multistream.context import get_multistream_comm_context -from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from 
vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, npu_stream_switch, npu_wait_tensor) @@ -141,7 +138,6 @@ class AscendMLATorchairMetadata: decode: Optional[AscendMLATorchairDecodeMetadata] = None prefill: Optional[AscendMLATorchairPrefillMetadata] = None - enable_dbo_across_dp: bool = False def __post_init__(self): pass @@ -152,17 +148,6 @@ def __post_init__(self): # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") - def split_metadata_for_multistream( - self, - ms_split_config: MSAttentionMetadataSplitConfig, - ) -> list["AscendMLATorchairMetadata"]: - """Split metadata for multi-stream with AscendMLATorchairMetadata""" - return model_input_split_v1_mla_attn( - ms_split_config=ms_split_config, - attn_metadata=self, - _metadata_cls=AscendMLATorchairMetadata, - ) - M = TypeVar("M", bound=AscendMLATorchairMetadata) @@ -576,7 +561,6 @@ def build( query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, @@ -1072,15 +1056,8 @@ def _forward_decode( context_lens=attn_metadata.decode.seq_lens, # type:ignore mla_vheadsize=self.kv_lora_rank, out=attn_output) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - return self._v_up_proj_and_o_proj(attn_output, - enable_multistream_mla) - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - return self._v_up_proj_and_o_proj(attn_output) + + return self._v_up_proj_and_o_proj(attn_output, enable_multistream_mla) def forward( self, @@ -1248,14 +1225,7 @@ def forward( prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is not None: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - o_proj_input[num_decode_tokens:] = output_prefill - else: - o_proj_input[num_decode_tokens:] = output_prefill + o_proj_input[num_decode_tokens:] = output_prefill if has_decode: if self.running_in_graph: @@ -1269,35 +1239,19 @@ def forward( decode_k_nope, decode_k_pe, kv_cache, attn_metadata) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is not None: - with torch.npu.stream(current_ms_metadata.comm_stream): - o_proj_input[:num_decode_tokens] = output_decode - else: - o_proj_input[:num_decode_tokens] = output_decode + o_proj_input[:num_decode_tokens] = output_decode - current_ms_metadata = get_multistream_comm_context() MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB - if current_ms_metadata is None: - maybe_npu_prefetch(self.o_proj.weight, - o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) - - output[...] = self.o_proj( - o_proj_input, - is_prefill=True, - is_force_scatter=self.enable_shared_expert_dp)[0] - else: - with torch.npu.stream(current_ms_metadata.comm_stream): - maybe_npu_prefetch(self.o_proj.weight, - o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) - output[...] 
= self.o_proj( - o_proj_input, - is_prefill=True, - is_force_scatter=self.enable_shared_expert_dp)[0] - current_ms_metadata.after_comm_event.record() + + maybe_npu_prefetch(self.o_proj.weight, + o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + + output[...] = self.o_proj( + o_proj_input, + is_prefill=True, + is_force_scatter=self.enable_shared_expert_dp)[0] + del o_proj_input return output_padded diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index db30831f0ef..56befcc1124 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -110,30 +110,28 @@ def _init_mc2_tokens_capacity(self): self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size def _sync_metadata_across_dp( - self, num_tokens: int, with_prefill: bool, enable_dbo: bool - ) -> tuple[int, Optional[torch.Tensor], bool, bool]: + self, num_tokens: int, + with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]: """Override from NPUModelRunner to pad num_tokens""" if self.enable_shared_expert_dp: # Padding is not required for shared_expert_dp cases in eager mode. - return num_tokens, None, with_prefill, enable_dbo + return num_tokens, None, with_prefill if self.dp_size == 1: if not with_prefill: maybe_padded_num_tokens = self.select_torchair_padded_batch_size( num_tokens) - return maybe_padded_num_tokens, None, with_prefill, enable_dbo - return num_tokens, None, with_prefill, enable_dbo + return maybe_padded_num_tokens, None, with_prefill + return num_tokens, None, with_prefill - num_tokens_across_dp = torch.zeros(self.dp_size + 2, + num_tokens_across_dp = torch.zeros(self.dp_size + 1, dtype=torch.int32, device="npu") num_tokens_across_dp[self.dp_rank] = num_tokens - num_tokens_across_dp[-2] = int(with_prefill) - num_tokens_across_dp[-1] = int(not enable_dbo) + num_tokens_across_dp[-1] = int(with_prefill) dist.all_reduce(num_tokens_across_dp, group=get_dp_group().device_group) - with_prefill = bool(num_tokens_across_dp[-2]) - enable_dbo = not bool(num_tokens_across_dp[-1]) - num_tokens_across_dp = num_tokens_across_dp[:-2] + with_prefill = bool(num_tokens_across_dp[-1]) + num_tokens_across_dp = num_tokens_across_dp[:-1] if not with_prefill: max_num_token = num_tokens_across_dp.max().item() @@ -146,7 +144,7 @@ def _sync_metadata_across_dp( else: maybe_padded_num_tokens = num_tokens - return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo + return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill def _build_dummy_attn_metadata( self, diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py index 0652281f28d..1390aee33d2 100644 --- a/vllm_ascend/torchair/torchair_sfa.py +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -21,8 +21,6 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, split_decodes_and_prefills) -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig -from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata from vllm_ascend.utils import is_enable_nz from vllm_ascend.worker.npu_input_batch import InputBatch @@ -141,7 +139,6 @@ class AscendSFATorchairMetadata: decode: Optional[AscendSFATorchairDecodeMetadata] = None prefill: Optional[AscendSFATorchairPrefillMetadata] = None - enable_dbo_across_dp: bool = False 
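With the multistream comm context removed, the torchair_mla forward epilogue above now writes the prefill and decode attention outputs into a single o_proj_input buffer and runs the weight prefetch plus output projection on the default stream, with no comm-stream branch or event handshakes. Below is a plain-PyTorch stand-in for that epilogue, a minimal sketch only: o_proj is reduced to a bare nn.Linear (in the module it is a parallel projection that takes is_prefill/is_force_scatter flags and returns a tuple), the NPU-specific maybe_npu_prefetch call is left as a comment, and all shapes are made up for illustration.

```python
import torch

num_decode_tokens, num_prefill_tokens = 4, 6
num_heads, v_head_dim, hidden_size = 8, 16, 32

# Stand-in for the output projection; the real module returns a tuple and takes
# extra flags, reduced here to a plain Linear for a runnable sketch.
o_proj = torch.nn.Linear(num_heads * v_head_dim, hidden_size, bias=False)

# Attention outputs produced by the decode and prefill branches.
output_decode = torch.randn(num_decode_tokens, num_heads * v_head_dim)
output_prefill = torch.randn(num_prefill_tokens, num_heads * v_head_dim)

# Both halves land in one buffer; no comm-stream branch, no extra events.
o_proj_input = torch.empty(num_decode_tokens + num_prefill_tokens,
                           num_heads * v_head_dim)
o_proj_input[num_decode_tokens:] = output_prefill
o_proj_input[:num_decode_tokens] = output_decode

# (maybe_npu_prefetch(self.o_proj.weight, o_proj_input, ...) would run here on Ascend.)
output = o_proj(o_proj_input)      # single projection on the default stream
del o_proj_input
print(output.shape)                # torch.Size([10, 32])
```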
is_prefill: bool = False is_decode: bool = False @@ -154,17 +151,6 @@ def __post_init__(self): # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") - def split_metadata_for_multistream( - self, - ms_split_config: MSAttentionMetadataSplitConfig, - ) -> list["AscendSFATorchairMetadata"]: - """Split metadata for multi-stream with AscendSFATorchairMetadata""" - return model_input_split_v1_mla_attn( - ms_split_config=ms_split_config, - attn_metadata=self, - _metadata_cls=AscendSFATorchairMetadata, - ) - M = TypeVar("M", bound=AscendSFATorchairMetadata) @@ -616,7 +602,6 @@ def build( query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, is_prefill=is_prefill, is_decode=is_decode) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 0c2ead8fae4..5a9c460d5ce 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -757,13 +757,13 @@ def get_default_buffer_config() -> dict: def calculate_dp_buffer_size() -> int: """ formula of dp buffer size: - dp_size + 2 (flags: with_prefill and enable_dbo) + dp_size + 1 (flags: with_prefill) """ from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() dp_size = vllm_config.parallel_config.data_parallel_size int32_size = torch.iinfo(torch.int32).bits // 8 - dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024)) + dp_buffer_size = math.ceil((dp_size + 1) * int32_size / (1024 * 1024)) return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6224eacbb1f..22d8f911e5b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -122,7 +122,6 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register -from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.logits_processor import build_logitsprocs @@ -859,8 +858,8 @@ def _init_mrope_positions(self, req_state: CachedRequestState): ) def _sync_metadata_across_dp( - self, num_tokens: int, with_prefill: bool, enable_dbo: bool - ) -> tuple[int, Optional[torch.Tensor], bool, bool]: + self, num_tokens: int, + with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]: # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in # our case, we still need to sync the other two flags as well. So we need to # include them in the all_reduce operation, and more over, we CANNOT skip it @@ -868,31 +867,29 @@ def _sync_metadata_across_dp( # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here # immediately once the other two flags are no longer needed. 
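The calculate_dp_buffer_size edit above follows directly from the smaller sync payload: dp_size per-rank token counts plus one with_prefill flag, four bytes each, rounded up to whole MiB and clamped to a minimum. A quick worked example of the updated formula; the concrete value of _MIN_DP_BUFFER_SIZE is assumed to be 1 MiB here purely for illustration.

```python
import math

_MIN_DP_BUFFER_SIZE = 1  # MiB, assumed value for this example only

def dp_buffer_size_mib(dp_size: int) -> int:
    int32_size = 4  # bytes per packed slot
    needed = math.ceil((dp_size + 1) * int32_size / (1024 * 1024))
    return max(needed, _MIN_DP_BUFFER_SIZE)

print(dp_buffer_size_mib(8))        # 1 -> 36 bytes of payload, the minimum dominates
print(dp_buffer_size_mib(1 << 20))  # 5 -> ~4 MiB of counts rounds up past the floor
```

For realistic dp_size values the ceil already lands at 1 MiB, so the clamp dominates; the edit mainly keeps the formula consistent with the new dp_size + 1 packing layout.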
if self.dp_size == 1: - return num_tokens, None, with_prefill, enable_dbo + return num_tokens, None, with_prefill - # Sync num_tokens, with_prefill, enable_dbo across dp ranks + # Sync num_tokens, with_prefill across dp ranks num_tokens_tensor = torch.tensor([ num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size) ], dtype=torch.int32, device="npu") - flags_tensor = torch.tensor( - [int(with_prefill), int(not enable_dbo)], - dtype=torch.int32, - device="npu") + flags_tensor = torch.tensor([int(with_prefill)], + dtype=torch.int32, + device="npu") packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) dist.all_reduce(packed_tensor, group=get_dp_group().device_group) # Unpack the results - num_tokens_across_dp = packed_tensor[:-2] - synced_flags = packed_tensor[-2:] + num_tokens_across_dp = packed_tensor[:-1] + synced_flags = packed_tensor[-1:] max_tokens_across_dp = torch.max(num_tokens_across_dp).item() global_with_prefill = bool(synced_flags[0]) - global_enable_dbo = not bool(synced_flags[1]) # Create a tensor for num_tokens_after_padding num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * @@ -900,28 +897,7 @@ def _sync_metadata_across_dp( device="cpu", dtype=torch.int32) - return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo - - def _check_dbo_is_valid(self, query_lens: torch.Tensor, - attn_state: AscendAttentionState, - num_tokens: int) -> bool: - # do the checks for dp + dbo - if attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - return False - # considering the case that one dp rank may enable dbo while others may not - if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: - return False - # TODO: remove it if token-level microbatch is enabled - [token_index, - seq_index] = compute_split_seq_index(query_lens, attn_state, - num_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - query_lens) or num_tokens < 256: - return False - return True + return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. @@ -1430,16 +1406,13 @@ def _prepare_inputs( ] self.query_lens = torch.from_numpy(num_scheduled_tokens) - enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), - attn_state, - total_num_scheduled_tokens) # Get info across DP ranks. 
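In the eager path shown here, _sync_metadata_across_dp now exchanges only the per-rank token counts and the with_prefill flag, packed into one int32 tensor so the whole sync remains a single all_reduce. The sketch below simulates that exchange locally by summing per-rank tensors, which is what the sum all_reduce computes across the DP group; pack_metadata and unpack_metadata are illustrative names, not functions in the runner.

```python
import torch

def pack_metadata(dp_size: int, dp_rank: int, num_tokens: int,
                  with_prefill: bool) -> torch.Tensor:
    packed = torch.zeros(dp_size + 1, dtype=torch.int32)
    packed[dp_rank] = num_tokens       # one slot per rank for its token count
    packed[-1] = int(with_prefill)     # shared flag, effectively OR-ed by the sum
    return packed

def unpack_metadata(reduced: torch.Tensor):
    num_tokens_across_dp = reduced[:-1]
    with_prefill = bool(reduced[-1])
    max_tokens = int(num_tokens_across_dp.max())
    return num_tokens_across_dp, with_prefill, max_tokens

# Simulate three DP ranks; only rank 1 still has a prefill in flight.
per_rank = [pack_metadata(3, r, t, p)
            for r, (t, p) in enumerate([(128, False), (512, True), (64, False)])]
reduced = torch.stack(per_rank).sum(dim=0)   # stands in for dist.all_reduce
counts, with_prefill, padded = unpack_metadata(reduced)
print(counts.tolist(), with_prefill, padded)  # [128, 512, 64] True 512
```

Because the flag travels through a sum reduction, any rank that still has a prefill in flight turns with_prefill on for every rank, which is why the comment above notes the all_reduce cannot be skipped even in eager mode.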
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, # Otherwise, it's just max_tokens_across_dp_cpu - (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, - enable_dbo) = self._sync_metadata_across_dp(num_input_tokens, - with_prefill, enable_dbo) + (maybe_padded_num_tokens, num_tokens_across_dp, + with_prefill) = self._sync_metadata_across_dp(num_input_tokens, + with_prefill) # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens # We should consider removing maybe_padded_num_tokens later @@ -1710,7 +1683,6 @@ def _prepare_inputs( attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, - enable_dbo_across_dp=enable_dbo, is_only_prefill=bool(np.all(num_valid_tokens != 1)), max_query_len=max_num_scheduled_tokens, graph_pad_size=self.graph_pad_size, @@ -2604,8 +2576,9 @@ def _dummy_run( num_tokens = math.ceil(num_tokens / tp_size) * tp_size # Padding for DP - (num_tokens, num_tokens_across_dp, with_prefill, - _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) + (num_tokens, num_tokens_across_dp, + with_prefill) = self._sync_metadata_across_dp(num_tokens, + with_prefill) moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
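For context on what this patch deletes: the DBO/multistream path split one batch's flattened attention metadata into two micro-batches at a request boundary chosen from the query lengths, so communication for one half could overlap with computation for the other (the ms_split logic and split_metadata_for_multistream hooks removed above, together with _check_dbo_is_valid and the enable_dbo_across_dp plumbing). The snippet below is a simplified stand-in for that splitting idea, not the removed vllm_ascend implementation; split_at_token and split_tensor are made-up names.

```python
import torch

def split_at_token(query_lens: torch.Tensor, target_tokens: int) -> tuple[int, int]:
    """First request boundary reaching target_tokens (assumes >= 2 requests)."""
    cum = torch.cumsum(query_lens, dim=0)
    seq_index = min(int(torch.searchsorted(cum, target_tokens)) + 1,
                    len(query_lens) - 1)
    return int(cum[seq_index - 1]), seq_index

def split_tensor(flat: torch.Tensor, token_index: int):
    """Slice a per-token tensor into the two micro-batches."""
    return flat[:token_index], flat[token_index:]

query_lens = torch.tensor([4, 7, 3, 9])                 # four requests, 23 tokens total
token_index, seq_index = split_at_token(query_lens, target_tokens=11)
slot_mapping = torch.arange(int(query_lens.sum()))      # stands in for any per-token tensor
pre, post = split_tensor(slot_mapping, token_index)
print(token_index, seq_index, len(pre), len(post))      # 11 2 11 12
```

In the deleted code, every per-token tensor (slot_mapping, input_positions, and so on) was sliced at token_index and every per-request tensor at seq_index, producing the pre/post metadata pairs built in the removed ms_split helpers.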