 # limitations under the License.
 #

+from itertools import accumulate
 from typing import Optional, Tuple

 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
+from vllm.config import get_current_vllm_config
 from vllm.distributed import divide, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs

-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
-from vllm_ascend.utils import lmhead_tp_enable
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import (get_embed_tp_group,
+                                                    get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable


 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
@@ -51,8 +56,15 @@ def __init__(self,
                  prefix: str = ""):
         nn.Module.__init__(self)

+        self.forward_type = None
         if lmhead_tp_enable() and prefix.find("lm_head") != -1:
             self.comm_group = get_lmhead_tp_group()
+        elif embedding_tp_enable() and prefix.find("embed_tokens") != -1:
+            self.comm_group = get_embed_tp_group()
+            self.forward_type = "embed_tp"
+            self.is_decode_only = get_current_vllm_config(
+            ).kv_transfer_config.is_kv_consumer
         else:
             self.comm_group = get_tp_group()

@@ -146,6 +158,54 @@ def _get_masked_input_and_mask(
         return input_, ~vocab_mask

     def forward(self, input_):
+        if self.forward_type == "embed_tp":
+            return self._forward_embed_tp(input_)
+        else:
+            return self._forward_origin(input_)
+
+    def _forward_embed_tp(self, input_):
+        if (not get_ascend_config().torchair_graph_config.enabled
+                and not self.is_decode_only):
+            # Token counts can differ across DP ranks here: recover each
+            # rank's batch size from the cumulative counts in the context.
+            cu_tokens_across_dp_cpu = get_forward_context(
+            ).dp_metadata.cu_tokens_across_dp_cpu
+            global_dp_batch_size = torch.diff(
+                cu_tokens_across_dp_cpu,
+                prepend=cu_tokens_across_dp_cpu.new_zeros(1))
+            embed_group_batch_size = [
+                global_dp_batch_size[x] for x in self.comm_group.ranks
+            ]
+            # Gather inputs from all embed TP ranks.
+            gathered_input = [
+                torch.empty(batch_size, dtype=input_.dtype, device='npu')
+                for batch_size in embed_group_batch_size
+            ]
+            torch.distributed.all_gather(gathered_input,
+                                         input_,
+                                         group=self.comm_group.device_group)
+            complete_input = torch.cat(gathered_input, dim=0)
+        else:
+            # Decode-only or torchair graph mode: batches are assumed equal
+            # across ranks, so a plain group all-gather suffices.
+            complete_input = self.comm_group.all_gather(input_, dim=0)
+            embed_group_batch_size = [input_.size(0)
+                                      ] * self.comm_group.world_size
+        # Mask input for vocab sharding
+        masked_input, input_mask = self._get_masked_input_and_mask(
+            complete_input, self.shard_indices.org_vocab_start_index,
+            self.shard_indices.org_vocab_end_index,
+            self.shard_indices.num_org_vocab_padding,
+            self.shard_indices.added_vocab_start_index,
+            self.shard_indices.added_vocab_end_index)
+        complete_output = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        complete_output.masked_fill_(input_mask.unsqueeze(-1), 0)
+        # Partial results from each vocab shard sum to the full embedding.
+        output = self.comm_group.all_reduce(complete_output)
+        # Slice the output back down to this rank's local batch portion.
+        prefix_sum = list(accumulate(embed_group_batch_size))
+        start_idx = prefix_sum[self.tp_rank - 1] if self.tp_rank > 0 else 0
+        end_idx = prefix_sum[self.tp_rank]
+        output = output[start_idx:end_idx]
+        return output
+
+    def _forward_origin(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = self._get_masked_input_and_mask(
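A minimal sketch (with hypothetical sizes, standalone and independent of vLLM internals) of the batch bookkeeping `_forward_embed_tp` relies on: per-rank token counts are recovered from the cumulative counts via `torch.diff`, and `itertools.accumulate` yields the prefix sums used to slice each rank's rows back out of the all-reduced, concatenated output.

# Sketch only: hypothetical cumulative token counts for 4 ranks,
# standing in for cu_tokens_across_dp_cpu.
from itertools import accumulate

import torch

cu_tokens = torch.tensor([3, 8, 10, 14])
# Per-rank batch sizes: diff against a prepended zero -> [3, 5, 2, 4].
per_rank = torch.diff(cu_tokens, prepend=cu_tokens.new_zeros(1)).tolist()

prefix_sum = list(accumulate(per_rank))  # [3, 8, 10, 14]

rank = 2  # hypothetical rank inside the embedding TP group
start = prefix_sum[rank - 1] if rank > 0 else 0  # 8
end = prefix_sum[rank]                           # 10
# Rank 2 keeps rows 8:10 of the concatenated output, i.e. exactly its own 2 tokens.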