
Commit 0fe35e0

zzh02232027 and zzhx1 authored and committed
[feat] add custom embedding tensor parallel
Signed-off-by: zzhx1 <[email protected]>
1 parent 07f4710 commit 0fe35e0

File tree

4 files changed, +105 -2 lines changed

vllm_ascend/ascend_config.py

Lines changed: 11 additions & 0 deletions
@@ -93,6 +93,17 @@ def __init__(self, vllm_config):
                 "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
             )

+        self.embedding_tensor_parallel_size = additional_config.get(
+            "embedding_tensor_parallel_size", None)
+        if self.embedding_tensor_parallel_size is not None:
+            logger.info(
+                f"Enable embedding_tensor_parallel_size = {self.embedding_tensor_parallel_size} in pure DP scenario"
+            )
+            if vllm_config.parallel_config.tensor_parallel_size != 1:
+                raise AssertionError(
+                    "embedding_tensor_parallel_size is only supported in the pure DP scenario"
+                )
+

 class TorchairGraphConfig:
     """

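For context, a hedged usage sketch of how this new option would be switched on from the vLLM side in a pure DP deployment. The model name, DP degree, and group size are illustrative assumptions, not part of the commit, and the keyword arguments follow recent vLLM releases:

# Hypothetical usage sketch (not from the commit): enable embedding tensor
# parallelism in vllm-ascend via additional_config. Only the constraint
# tensor_parallel_size == 1 (pure DP) comes from the code above; all other
# values are made up for illustration.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",         # example model
    tensor_parallel_size=1,                   # must stay 1: pure DP scenario
    data_parallel_size=4,                     # example DP degree
    additional_config={
        "embedding_tensor_parallel_size": 4,  # shard embed_tokens over 4 ranks
    },
)
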
vllm_ascend/distributed/parallel_state.py

Lines changed: 28 additions & 0 deletions
@@ -13,6 +13,7 @@
 _MLP_TP: Optional[GroupCoordinator] = None
 _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
+_EMBED_TP: Optional[GroupCoordinator] = None


 def get_mc2_group() -> GroupCoordinator:

@@ -37,6 +38,11 @@ def get_mlp_tp_group() -> GroupCoordinator:
     return _MLP_TP


+def get_embed_tp_group() -> GroupCoordinator:
+    assert _EMBED_TP is not None, ("emtp group is not initialized")
+    return _EMBED_TP
+
+
 def model_parallel_initialized():
     return (_MC2 is not None)

@@ -111,6 +117,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                           backend,
                                           group_name="lmheadtp")

+    embedding_tensor_parallel_size = get_ascend_config(
+    ).embedding_tensor_parallel_size
+    if embedding_tensor_parallel_size is not None:
+        group_ranks = []
+        global _EMBED_TP
+        num_embedding_tensor_parallel_groups: int = (
+            world_size // embedding_tensor_parallel_size)
+        for i in range(num_embedding_tensor_parallel_groups):
+            ranks = list(
+                range(i * embedding_tensor_parallel_size,
+                      (i + 1) * embedding_tensor_parallel_size))
+            group_ranks.append(ranks)
+        _EMBED_TP = init_model_parallel_group(group_ranks,
+                                              get_world_group().local_rank,
+                                              backend,
+                                              group_name="emtp")
+

 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""

@@ -142,3 +165,8 @@ def destroy_ascend_model_parallel():
     if _OTP:
         _OTP.destroy()
     _OTP = None
+
+    global _EMBED_TP
+    if _EMBED_TP:
+        _EMBED_TP.destroy()
+    _EMBED_TP = None

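The group construction above simply carves the world into consecutive blocks of embedding_tensor_parallel_size ranks. A standalone sketch of that arithmetic; the function name and the sizes are illustrative, not from the commit:

# Standalone sketch of the rank grouping done in init_ascend_model_parallel.
# world_size and embedding_tensor_parallel_size values are illustrative.
def build_embed_tp_group_ranks(world_size: int,
                               embedding_tensor_parallel_size: int):
    """Split ranks 0..world_size-1 into consecutive groups of the given size."""
    num_groups = world_size // embedding_tensor_parallel_size
    return [
        list(range(i * embedding_tensor_parallel_size,
                   (i + 1) * embedding_tensor_parallel_size))
        for i in range(num_groups)
    ]

print(build_embed_tp_group_ranks(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
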
vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 62 additions & 2 deletions
@@ -15,13 +15,16 @@
 # limitations under the License.
 #

+from itertools import accumulate
 from typing import Optional, Tuple

 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
+from vllm.config import get_current_vllm_config
 from vllm.distributed import divide, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)

@@ -30,8 +33,10 @@
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs

-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
-from vllm_ascend.utils import lmhead_tp_enable
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import (get_embed_tp_group,
+                                                    get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable


 class AscendVocabParallelEmbedding(VocabParallelEmbedding):

@@ -51,8 +56,15 @@ def __init__(self,
                  prefix: str = ""):
         nn.Module.__init__(self)

+        self.forward_type = None
         if lmhead_tp_enable() and prefix.find("lm_head") != -1:
             self.comm_group = get_lmhead_tp_group()
+        elif embedding_tp_enable() and prefix.find("embed_tokens") != -1:
+            self.comm_group = get_embed_tp_group()
+            self.forward_type = "embed_tp"
+            self.is_decode_only = get_current_vllm_config(
+            ).kv_transfer_config.is_kv_consumer
+            self.forward_type = "embed_tp"
         else:
             self.comm_group = get_tp_group()

@@ -146,6 +158,54 @@ def _get_masked_input_and_mask(
         return input_, ~vocab_mask

     def forward(self, input_):
+        if self.forward_type == "embed_tp":
+            return self._forward_embed_tp(input_)
+        else:
+            return self._forward_origin(input_)
+
+    def _forward_embed_tp(self, input_):
+        if get_ascend_config(
+        ).torchair_graph_config.enabled is False and not self.is_decode_only:
+            cu_tokens_across_dp_cpu = get_forward_context(
+            ).dp_metadata.cu_tokens_across_dp_cpu
+            global_dp_batch_size = torch.diff(
+                cu_tokens_across_dp_cpu,
+                prepend=cu_tokens_across_dp_cpu.new_zeros(1))
+            embedd_group_batch_size = [
+                global_dp_batch_size[x] for x in self.comm_group.ranks
+            ]
+            # Gather inputs from all embed TP ranks
+            gathered_input = [
+                torch.empty(batch_size, dtype=input_.dtype, device='npu')
+                for batch_size in embedd_group_batch_size
+            ]
+            torch.distributed.all_gather(gathered_input,
+                                         input_,
+                                         group=self.comm_group.device_group)
+            complete_input = torch.cat(gathered_input, dim=0)
+        else:
+            complete_input = self.comm_group.all_gather(input_, dim=0)
+            embedd_group_batch_size = [input_.size(0)
+                                       ] * self.comm_group.world_size
+        # Mask input for vocab sharding
+        masked_input, input_mask = self._get_masked_input_and_mask(
+            complete_input, self.shard_indices.org_vocab_start_index,
+            self.shard_indices.org_vocab_end_index,
+            self.shard_indices.num_org_vocab_padding,
+            self.shard_indices.added_vocab_start_index,
+            self.shard_indices.added_vocab_end_index)
+        complete_output = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        complete_output.masked_fill_(input_mask.unsqueeze(-1), 0)
+        output = self.comm_group.all_reduce(complete_output)
+        # Slice output to return only local batch portion
+        prefix_sum = list(accumulate(embedd_group_batch_size))
+        start_idx = prefix_sum[self.tp_rank - 1] if self.tp_rank > 0 else 0
+        end_idx = prefix_sum[self.tp_rank]
+        output = output[start_idx:end_idx]
+        return output
+
+    def _forward_origin(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = self._get_masked_input_and_mask(

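To make the batch bookkeeping in _forward_embed_tp easier to follow: each rank in the embedding-TP group contributes a different number of tokens, the group all-gathers and all-reduces over the concatenated batch, and the prefix sum of per-rank batch sizes tells each rank which rows to slice back out. A device-free sketch of just that slice arithmetic; the per-rank counts below are made up, whereas in the commit they come from dp_metadata.cu_tokens_across_dp_cpu:

from itertools import accumulate

# Illustrative per-rank token counts within one embedding-TP group.
embedd_group_batch_size = [5, 3, 7, 2]

def local_slice(tp_rank: int) -> tuple:
    # Same arithmetic as _forward_embed_tp: prefix sums give the [start, end)
    # row range that belongs to this rank after the group-wide all-reduce.
    prefix_sum = list(accumulate(embedd_group_batch_size))
    start_idx = prefix_sum[tp_rank - 1] if tp_rank > 0 else 0
    end_idx = prefix_sum[tp_rank]
    return start_idx, end_idx

for rank in range(len(embedd_group_batch_size)):
    print(rank, local_slice(rank))
# 0 (0, 5)
# 1 (5, 8)
# 2 (8, 15)
# 3 (15, 17)
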
vllm_ascend/utils.py

Lines changed: 4 additions & 0 deletions
@@ -574,6 +574,10 @@ def lmhead_tp_enable() -> bool:
     return get_ascend_config().lmhead_tensor_parallel_size is not None


+def embedding_tp_enable() -> bool:
+    return get_ascend_config().embedding_tensor_parallel_size is not None
+
+
 def oproj_tp_enable() -> bool:
     return get_ascend_config().oproj_tensor_parallel_size is not None