# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch

from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


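# These tests drive the real prefill context parallel (PCP) helpers of
# NPUModelRunner against a MagicMock runner: only the methods under test
# execute production code, while the rest of the runner state is stubbed.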
@pytest.mark.parametrize(
    "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
    [
        (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
        (1, 2, 3, [20, 30, 40], 1, False, 50, True),
        (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
        (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
        (2, 1, 3, [5, 10, 15], 3, False, 50, True),
        (2, 1, 3, [40, 50, 60], 0, False, 150, True),
    ])
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                                     num_decodes, use_mla, total_tokens,
                                     expect_not_none):
    mock_runner = MagicMock(spec=NPUModelRunner)
    mock_runner.pcp_size = pcp_size
    mock_runner.dcp_size = dcp_size
    mock_runner.decode_threshold = 4
    mock_runner.pcp_rank = 0
    mock_runner.device = torch.device('cpu')
    mock_runner.dtype = torch.float32

    mock_runner.parallel_config = MagicMock()
    mock_runner.parallel_config.cp_kv_cache_interleave_size = 64

    mock_runner.vllm_config = MagicMock()
    mock_runner.vllm_config.model_config = MagicMock()
    mock_runner.vllm_config.model_config.use_mla = use_mla

    mock_runner.input_batch = MagicMock()
    mock_runner.input_batch.num_reqs = num_reqs

    num_computed_tokens = []
    num_prompt_tokens = []
    num_tokens = []

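    # The first `num_decodes` requests model decodes whose prompt is already
    # fully computed; the remaining requests model fresh prefills with no
    # computed tokens yet.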
    for i in range(num_reqs):
        if i < num_decodes:
            num_computed_tokens.append(query_lens[i])
            num_prompt_tokens.append(query_lens[i] // 2)
            num_tokens.append(query_lens[i])
        else:
            num_computed_tokens.append(0)
            num_prompt_tokens.append(query_lens[i])
            num_tokens.append(query_lens[i])

    mock_runner.input_batch.num_computed_tokens_cpu = torch.tensor(
        num_computed_tokens)
    mock_runner.input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
    mock_runner.input_batch.num_tokens = torch.tensor(num_tokens)

    mock_runner.query_lens = torch.tensor(query_lens)

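    # Bind the real (unbound) methods onto the mock via __get__ so the
    # production implementations run against the stubbed runner state.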
    mock_runner._get_cp_local_seq_lens = NPUModelRunner._get_cp_local_seq_lens.__get__(
        mock_runner, NPUModelRunner)

    mock_runner.pcp_allgather_restore_idx = torch.arange(total_tokens * 2)
    mock_runner.cp_kv_recover_idx_for_chunk = torch.arange(total_tokens)

    mock_runner.long_seq_metadata = None
    mock_runner.num_actual_tokens_pcp_padded = 0
    mock_runner.kv_idx_names = {}
    mock_runner.extra_long_seq_kwargs = {}
    mock_runner.attn_mask = None
    mock_runner.q_head_idx_tensor = None
    mock_runner.q_tail_idx_tensor = None
    mock_runner.q_full_idx = None

    method = NPUModelRunner._generate_pcp_metadata.__get__(
        mock_runner, NPUModelRunner)
    result = method(total_tokens)

    if not expect_not_none:
        assert result is None, f"Expected to return None, but got {type(result)}"
    else:
        assert result is not None, "Expected to return a metadata object, but got None."

        assert hasattr(result, 'num_actual_tokens_pcp_padded')
        assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')

        if pcp_size > 1:
            assert hasattr(result, 'pcp_allgather_restore_idx')

        has_prefill_requests = (num_reqs - num_decodes) > 0
        if has_prefill_requests:
            assert hasattr(result, 'q_head_idx_tensor')
            assert hasattr(result, 'q_tail_idx_tensor')
            assert hasattr(result, 'q_full_idx')
            assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor')
            assert hasattr(result, 'kv_with_q_head_mask_idx_tensor')
            assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor')
            assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor')
            assert hasattr(result, 'attn_mask_seqlens')
            assert hasattr(result, 'head_attn_nomask_seqlens')
            assert hasattr(result, 'tail_attn_nomask_seqlens')

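        # If the metadata carries a prefill mask, its shape depends only on
        # the attention backend: 512 x 512 when MLA is enabled, otherwise
        # 2048 x 2048.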
        if getattr(result, 'pcp_prefill_mask', None) is not None:
            if use_mla:
                assert result.pcp_prefill_mask.shape == (512, 512)
            else:
                assert result.pcp_prefill_mask.shape == (2048, 2048)


def test_generate_pcp_metadata_edge_cases():
    mock_runner = MagicMock()
    mock_runner.pcp_size = 2
    mock_runner.dcp_size = 1
    mock_runner.input_batch = MagicMock()
    mock_runner.input_batch.num_reqs = 0
    mock_runner.query_lens = torch.tensor([10, 20, 30])

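    # num_reqs is 0 (falsy), so the `or` below falls back to
    # query_lens.size(0).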
    assert (mock_runner.input_batch.num_reqs
            or mock_runner.query_lens.size(0)) == 3

    mock_runner.input_batch.num_reqs = 100
    mock_runner.query_lens = torch.ones(100) * 1000

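    # Each sequence is split into 2 * pcp_size = 4 chunks; rank r takes head
    # chunk r and tail chunk (2 * pcp_size - 1 - r), pairing an early chunk
    # with a late one to balance causal-attention work across ranks.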
    for rank in [0, 1]:
        mock_runner.pcp_rank = rank
        q_head_chunk_id = rank
        q_tail_chunk_id = 2 * 2 - 1 - rank
        assert q_head_chunk_id == rank
        assert q_tail_chunk_id == 3 - rank


def test_pcp_allgather_restore_idx_slicing():
    mock_runner = MagicMock()
    mock_runner.pcp_size = 2
    mock_runner.pcp_allgather_restore_idx = torch.arange(1000)

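    # With pcp_size=2, the allgather yields twice the scheduled token count,
    # and the restore index is sliced to exactly that padded length.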
    total_num_scheduled_tokens = 200
    num_actual_tokens_pcp_padded = total_num_scheduled_tokens * 2

    expected_slice = mock_runner.pcp_allgather_restore_idx[
        :num_actual_tokens_pcp_padded]
    assert len(expected_slice) == 400
    assert expected_slice[0] == 0
    assert expected_slice[-1] == 399


@pytest.mark.parametrize(
    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
    [
        # Case 1: prefill only
        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),

        # Case 2: mixed prefill and decode
        ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, [8, 4, 4]),

        # Case 3: requests that need padding
        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),

        # Case 4: single request
        ([10], 1, [0], [10], 4, 0, [4]),
    ])
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                     num_prompt_tokens, pcp_size, pcp_rank,
                                     expected_pcp_tokens):
    mock_runner = MagicMock(spec=NPUModelRunner)
    mock_runner.pcp_size = pcp_size
    mock_runner.pcp_rank = pcp_rank

    mock_runner.input_batch = MagicMock()
    mock_runner.input_batch.num_reqs = num_reqs
    mock_runner.input_batch.num_computed_tokens_cpu = np.array(
        num_computed_tokens, dtype=np.int32)
    mock_runner.input_batch.num_prompt_tokens = np.array(num_prompt_tokens,
                                                         dtype=np.int32)

    mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)

    mock_runner.num_pcp_pads = [0] * num_reqs
    mock_runner.arange_np = np.arange(10000)

    mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
        mock_runner, NPUModelRunner)
    mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
        mock_runner, NPUModelRunner)

    pcp_tokens_result, positions_result, unpad_mask_result = mock_runner._update_tokens_for_pcp(
        tokens)

    assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
        f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"

    total_pcp_tokens: int = np.sum(pcp_tokens_result)
    assert positions_result.shape == (total_pcp_tokens,), \
        f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"

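    # Prefill requests (num_computed_tokens == 0) are padded up to a multiple
    # of 2 * pcp_size before being split across ranks; decode requests are
    # replicated on every pcp rank, contributing t * pcp_size tokens to the
    # allgathered (padded) total checked below.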
    padded_tokens = [
        (t + 2 * pcp_size - 1) // (2 * pcp_size) *
        (2 * pcp_size) if num_computed_tokens[i] == 0 else t * pcp_size
        for i, t in enumerate(tokens)
    ]
    total_padded_tokens: int = np.sum(padded_tokens)
    assert unpad_mask_result.shape[0] == total_padded_tokens, \
        f"unpad_mask size mismatch: expected {total_padded_tokens}, got {unpad_mask_result.shape[0]}"


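# A minimal reference sketch of the prefill padding math asserted above,
# assuming each prefill request is padded to a multiple of 2 * pcp_size so
# every rank receives one head chunk and one tail chunk of equal size.
# `_pad_prefill_tokens_sketch` is a hypothetical helper for illustration,
# not part of NPUModelRunner.
def _pad_prefill_tokens_sketch(num_tokens: int, pcp_size: int):
    """Return (tokens_per_rank, num_pads) for a single prefill request."""
    chunk = 2 * pcp_size
    padded = (num_tokens + chunk - 1) // chunk * chunk
    return padded // pcp_size, padded - num_tokens


# For example, with pcp_size=4: 5 -> (2, 3), 9 -> (4, 7), 13 -> (4, 3),
# matching the expectations asserted in the padding test below.

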
def test_update_tokens_for_pcp_with_padding():
    mock_runner = MagicMock(spec=NPUModelRunner)
    mock_runner.pcp_size = 4
    mock_runner.pcp_rank = 0

    mock_runner.arange_np = np.arange(10000)

    mock_runner.input_batch = MagicMock()
    mock_runner.input_batch.num_reqs = 3
    mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0, 0],
                                                               dtype=np.int32)
    mock_runner.input_batch.num_prompt_tokens = np.array([5, 9, 13],
                                                         dtype=np.int32)

    mock_runner.num_pcp_pads = [0, 0, 0]
    mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)

    mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
        mock_runner, NPUModelRunner)
    mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
        mock_runner, NPUModelRunner)

    tokens = [5, 9, 13]

    pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp(
        tokens)

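    # 5 -> padded to 8 -> 2 tokens per rank, 3 pads;
    # 9 -> padded to 16 -> 4 tokens per rank, 7 pads;
    # 13 -> padded to 16 -> 4 tokens per rank, 3 pads.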
    expected_pcp_tokens = [2, 4, 4]
    assert np.array_equal(pcp_tokens, expected_pcp_tokens), \
        f"Expected {expected_pcp_tokens}, got {pcp_tokens}"

    expected_pads = [3, 7, 3]
    assert np.array_equal(mock_runner.num_pcp_pads, expected_pads), \
        f"Expected padding {expected_pads}, got {mock_runner.num_pcp_pads}"


def test_update_tokens_for_pcp_unpad_mask():
    mock_runner = MagicMock(spec=NPUModelRunner)
    mock_runner.pcp_size = 4
    mock_runner.pcp_rank = 0

    mock_runner.arange_np = np.arange(10000)

    mock_runner.input_batch = MagicMock()
    mock_runner.input_batch.num_reqs = 2
    mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0],
                                                               dtype=np.int32)
    mock_runner.input_batch.num_prompt_tokens = np.array([5, 7],
                                                         dtype=np.int32)

    mock_runner.num_pcp_pads = [0, 0]
    mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)

    mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
        mock_runner, NPUModelRunner)
    mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
        mock_runner, NPUModelRunner)

    tokens = [5, 7]

    pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp(
        tokens)

    assert unpad_mask.dtype == torch.bool, \
        f"unpad_mask should be bool, got {unpad_mask.dtype}"

    padded_tokens = [8, 8]
    expected_length = sum(padded_tokens)
    assert unpad_mask.shape[0] == expected_length, \
        f"unpad_mask length mismatch: expected {expected_length}, got {unpad_mask.shape[0]}"

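    # Request 0: 5 real tokens padded to 8 -> three trailing False entries;
    # request 1: 7 real tokens padded to 8 -> one trailing False entry.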
    expected_mask = [True] * 5 + [False] * 3 + [True] * 7 + [False] * 1
    actual_mask = unpad_mask.numpy().tolist()
    assert actual_mask == expected_mask, \
        f"unpad_mask incorrect. Expected {expected_mask}, got {actual_mask}"