Commit 292e213

[main][refactor] refactor SequenceRowParallelOp forward (#3616)
### What this PR does / why we need it?

This PR refactors the forward path of `SequenceRowParallelOp`. To further widen the set of operators that can be masked behind a custom op when dynamic runtime branching is involved, it wraps the entire matmul computation and its communication into a single masking custom operator. With this refactor, common operation fusions can be written directly into the `SequenceRowParallelOp` member function `matmul_and_reduce`, without registering additional redundant masking custom operators.

### How was this patch tested?

CI passed with existing tests.

Signed-off-by: rjg-lyh <[email protected]>
1 parent ca104ce commit 292e213

File tree: 4 files changed, +56 −4 lines

tests/ut/ops/test_linear.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,7 @@
 from unittest.mock import MagicMock, patch

 import torch
+from vllm import config

 from tests.ut.base import TestBase
 from vllm_ascend import ascend_config
@@ -106,6 +107,9 @@ def test_mlp_optimize(self):
         linear(input_tensor)

     def test_oproj_tp(self):
+
+        config._current_vllm_config = MagicMock()
+
         ascend_config._ASCEND_CONFIG = MagicMock()
         ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
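The new mock is needed because the layer constructed in this test now reads the global vLLM config during `__init__` (see the `vllm_ascend/ops/linear.py` hunk below). A minimal, self-contained sketch of that mocking pattern, using toy stand-ins rather than the real `vllm.config` module:

```python
# Toy stand-ins (not the real vllm.config): the constructor under test reads a
# module-level config global, so the test replaces that global with a MagicMock
# whose attribute chain (.compilation_config.static_forward_context) resolves
# to auto-created nested mocks and accepts item assignment.
from types import SimpleNamespace
from unittest.mock import MagicMock

toy_config = SimpleNamespace(_current_vllm_config=None)  # stand-in for vllm.config


def get_current_vllm_config():
    """Stand-in for vllm.config.get_current_vllm_config()."""
    return toy_config._current_vllm_config


class ToyRowParallelLinear:
    def __init__(self, prefix: str):
        # Mirrors the registration added in vllm_ascend/ops/linear.py.
        compilation_config = get_current_vllm_config().compilation_config
        compilation_config.static_forward_context[prefix] = self


toy_config._current_vllm_config = MagicMock()             # what the test sets up
ToyRowParallelLinear("model.layers.0.self_attn.o_proj")   # constructs without a real config
```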

vllm_ascend/ops/linear.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,7 @@
 import torch.nn as nn
 import torch_npu
 from torch.nn.parameter import Parameter
+from vllm.config import get_current_vllm_config
 from vllm.distributed import divide
 from vllm.model_executor.layers.linear import ( # noqa
     WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
@@ -234,6 +235,13 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
+        compilation_config = get_current_vllm_config().compilation_config
+        # TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete.
+        if prefix in compilation_config.static_forward_context and \
+                "visual" not in prefix:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
         self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
             disable_tp, prefix, self, "row")
         # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
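To make the registration side of this hunk concrete, here is a minimal sketch of the pattern with a plain dict standing in for vLLM's `compilation_config.static_forward_context`: each layer is stored under its `prefix`, duplicate names are rejected (with the temporary `"visual"` carve-out from the TODO), and the object can later be recovered by name, which is what the custom op in `register_custom_ops.py` relies on.

```python
# Plain-dict sketch (not the real CompilationConfig) of the per-prefix layer
# registry that RowParallelLinear.__init__ now populates.
static_forward_context: dict[str, object] = {}


def register_layer(prefix: str, layer: object) -> None:
    # Reject duplicate prefixes, except for "visual" layers (temporary
    # carve-out noted in the TODO above).
    if prefix in static_forward_context and "visual" not in prefix:
        raise ValueError(f"Duplicate layer name: {prefix}")
    static_forward_context[prefix] = layer


class ToyLayer:
    pass


register_layer("model.layers.0.self_attn.o_proj", ToyLayer())
# Later, a custom op can recover the layer object from just its name:
layer = static_forward_context["model.layers.0.self_attn.o_proj"]
print(type(layer).__name__)   # ToyLayer
```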

vllm_ascend/ops/linear_op.py

Lines changed: 13 additions & 4 deletions
@@ -366,14 +366,23 @@ def apply_impl(
                                                       input_parallel,
                                                       bias=bias_)
         else:
-            output_parallel = self.quant_method.apply(self.layer,
-                                                      input_parallel,
-                                                      bias=bias_)
-            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+            output = torch.ops.vllm.matmul_and_reduce(input_parallel,
+                                                      self.prefix)

         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias

+    def matmul_and_reduce(self, input_parallel: torch.Tensor,
+                          bias_: Optional[Parameter]) -> torch.Tensor:
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self.layer,
+                                                  input_parallel,
+                                                  bias=bias_)
+        from vllm_ascend.ops.register_custom_ops import \
+            _maybe_pad_and_reduce_impl
+        output = _maybe_pad_and_reduce_impl(output_parallel)
+        return output
+
     def update_attrs(self):
         super().update_attrs()
         self.input_is_parallel = self.layer.input_is_parallel
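The new member function is the extension point the commit message refers to. As a rough, self-contained sketch (a toy class, not the real `SequenceRowParallelOp`, with the all-reduce collapsed to a no-op for a single process), this is the shape of the hook where a fused matmul-plus-reduce kernel could later be dropped in without registering another masking custom op:

```python
from typing import Optional

import torch


class ToyRowParallelOp:
    """Toy stand-in for SequenceRowParallelOp; single-process, no real comm."""

    def __init__(self, weight: torch.Tensor):
        self.weight = weight  # [out_features, in_features_per_rank]

    def matmul_and_reduce(self, input_parallel: torch.Tensor,
                          bias_: Optional[torch.Tensor]) -> torch.Tensor:
        # Matmul followed by the (here: no-op) reduce. Because both steps now
        # live behind one custom-op boundary, a fused matmul+all-reduce kernel
        # could replace these two lines without any new op registration.
        partial = torch.nn.functional.linear(input_parallel, self.weight, bias_)
        return self._all_reduce(partial)

    @staticmethod
    def _all_reduce(t: torch.Tensor) -> torch.Tensor:
        return t  # stands in for maybe_pad_and_reduce / all-reduce


op = ToyRowParallelOp(torch.randn(16, 8))
print(op.matmul_and_reduce(torch.randn(4, 8), None).shape)  # torch.Size([4, 16])
```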

vllm_ascend/ops/register_custom_ops.py

Lines changed: 31 additions & 0 deletions
@@ -235,6 +235,31 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     return tensor_model_parallel_all_reduce(final_hidden_states)


+def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
+                            layer_name: str) -> torch.Tensor:
+    forward_context = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    assert self.custom_op is not None
+    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+    output = self.custom_op.matmul_and_reduce(input_parallel, bias_)
+
+    return output
+
+
+def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
+                                 layer_name: str) -> torch.Tensor:
+    forward_context = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    num_tokens = input_parallel.size(0)
+    if forward_context.sp_enabled:
+        num_tokens = num_tokens // self.tp_size
+    output = torch.empty(size=(num_tokens, self.output_size_per_partition),
+                         device=input_parallel.device,
+                         dtype=input_parallel.dtype)
+
+    return output
+
+
 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
                           fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
@@ -282,3 +307,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
                           fake_impl=lambda x: x,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="matmul_and_reduce",
+                          op_func=_matmul_and_reduce_impl,
+                          fake_impl=_matmul_and_reduce_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
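`direct_register_custom_op` is vLLM's helper for registering an op with the PyTorch dispatcher; the pairing of a real and a fake implementation matters because the fake (shape-only) version is what lets tracing and `torch.compile` infer output shapes without running device code. As a standalone illustration of that pairing only, using plain `torch.library` (PyTorch ≥ 2.4) and single-rank toy math rather than the vLLM helper:

```python
import torch
from torch.library import custom_op


# Real implementation: runs the actual math when the op executes eagerly.
@custom_op("toy::matmul_and_reduce", mutates_args=())
def toy_matmul_and_reduce(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w.t()  # single-rank toy: the "reduce" is a no-op


# Fake (meta) implementation: shape/dtype only, mirroring the role of
# _matmul_and_reduce_impl_fake above.
@toy_matmul_and_reduce.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.empty((x.size(0), w.size(0)), device=x.device, dtype=x.dtype)


x, w = torch.randn(4, 8), torch.randn(16, 8)
print(torch.ops.toy.matmul_and_reduce(x, w).shape)  # torch.Size([4, 16])
```

In the real fake impl above, the token count is divided by `tp_size` when `sp_enabled` is set, presumably because the reduce becomes a reduce-scatter under sequence parallelism and each rank keeps only its slice of the tokens.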
