 from torch import nn
 from torch.nn.parameter import Parameter
 from vllm.distributed import divide, tensor_model_parallel_all_reduce
-from vllm.distributed.parallel_state import get_tp_group
-import torch.distributed as dist
+from vllm.distributed.parallel_state import get_dp_group, get_tp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.distributed.parallel_state import get_dp_group
-from vllm.forward_context import get_forward_context
 from vllm.utils import logger
 
-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group, get_emtp_group
-from vllm_ascend.utils import lmhead_tp_enable, embedding_tp_enable
+from vllm_ascend.distributed.parallel_state import (get_emtp_group,
+                                                     get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable
 
 
 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
@@ -150,30 +149,41 @@ def _get_masked_input_and_mask(
         input_ = vocab_mask * (input_ - valid_offset)
         return input_, ~vocab_mask
 
-    def _get_local_batch_slice(self, tensor: torch.Tensor,
-                               batch_sizes: list,
-                               local_batch_size: int,
-                               rank: int) -> torch.Tensor:
+    def _get_local_batch_slice(self, tensor: torch.Tensor, batch_sizes: list,
+                               local_batch_size: int,
+                               rank: int) -> torch.Tensor:
         end_idx = batch_sizes[rank]
         start_idx = end_idx - local_batch_size
         return tensor[start_idx:end_idx]
-
+
     def forward(self, input_):
         if embedding_tp_enable():
-            logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
+            logger.info(
+                f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
             return self._forward_embed_tp(input_)
         else:
             return self._forward_normal(input_)
-
+
     def _forward_embed_tp(self, input_):
-        cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-        global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu, prepend=cu_tokens_across_dp_cpu.new_zeros(1))
-        logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size} \n")
-        lmhead_group_batch_size = [global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks]
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        global_dp_batch_size = torch.diff(
+            cu_tokens_across_dp_cpu,
+            prepend=cu_tokens_across_dp_cpu.new_zeros(1))
+        logger.info(
+            f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size} \n"
+        )
+        lmhead_group_batch_size = [
+            global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks
+        ]
         local_batch_size = input_.size(0)
-        gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu') for batch_size in lmhead_group_batch_size]
-        torch.distributed.all_gather(
-            gathered_input, input_, group=get_lmhead_tp_group().device_group)
+        gathered_input = [
+            torch.empty(batch_size, dtype=input_.dtype, device='npu')
+            for batch_size in lmhead_group_batch_size
+        ]
+        torch.distributed.all_gather(gathered_input,
+                                     input_,
+                                     group=get_lmhead_tp_group().device_group)
         complete_input = torch.cat(gathered_input, dim=0)
         masked_input, input_mask = self._get_masked_input_and_mask(
             complete_input, self.shard_indices.org_vocab_start_index,
@@ -182,43 +192,48 @@ def _forward_embed_tp(self, input_):
             self.shard_indices.added_vocab_start_index,
             self.shard_indices.added_vocab_end_index)
         logger.info(f"all_gather_down complete_input: {complete_input.shape}")
-
+
         output = self.quant_method.embedding(self, masked_input.long())
         output.masked_fill_(input_mask.unsqueeze(-1), 0)
         output = tensor_model_parallel_all_reduce(output)
         # output = output[lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]-local_batch_size:lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]]
         # Extract the local batch portion from the gathered output
         lmhead_tp_group = get_lmhead_tp_group()
-        output = self._get_local_batch_slice(
-            output,
-            lmhead_group_batch_size,
-            local_batch_size,
-            lmhead_tp_group.rank_in_group
-        )
-        logger.info(f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
+        output = self._get_local_batch_slice(output, lmhead_group_batch_size,
+                                             local_batch_size,
+                                             lmhead_tp_group.rank_in_group)
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
         return output
 
     def _forward_normal(self, input_):
         if self.tp_size > 1:
             # Build the mask.
-            masked_input, input_mask = get_masked_input_and_mask(
+            masked_input, input_mask = self._get_masked_input_and_mask(
                 input_, self.shard_indices.org_vocab_start_index,
                 self.shard_indices.org_vocab_end_index,
                 self.shard_indices.num_org_vocab_padding,
                 self.shard_indices.added_vocab_start_index,
                 self.shard_indices.added_vocab_end_index)
         else:
             masked_input = input_
-        logger.info(f"rank:{get_dp_group().rank_in_group} masked_input:{masked_input.shape}")
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} masked_input:{masked_input.shape}"
+        )
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
-        logger.info(f"rank:{get_dp_group().rank_in_group} output_parallel:{output_parallel.shape}")
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} output_parallel:{output_parallel.shape}"
+        )
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         # Reduce across all the model parallel GPUs.
         output = tensor_model_parallel_all_reduce(output_parallel)
-        logger.info(f"rank:{get_dp_group().rank_in_group} forward_normal output:{output.shape}")
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} forward_normal output:{output.shape}"
+        )
         return output
 
 
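For readers tracing the embedding-TP path in `_forward_embed_tp`, here is a minimal standalone sketch of the gather-then-slice idea: every rank's local batch is concatenated in rank order, and `_get_local_batch_slice`-style arithmetic recovers the rows belonging to one rank. The group size, token counts, the single-process setup, and the use of cumulative end offsets as the sizes argument are illustrative assumptions, not taken from this commit.

```python
import torch

# Hypothetical per-rank token counts for a 4-rank lmhead TP group (assumption).
per_rank_tokens = torch.tensor([2, 3, 4, 3])
cu_tokens = torch.cumsum(per_rank_tokens, dim=0)  # cumulative end offsets: [2, 5, 9, 12]
hidden_size = 8

# Stand-in for the all_gather + cat result: all ranks' rows concatenated in rank order.
complete_output = torch.randn(int(cu_tokens[-1]), hidden_size)


def get_local_batch_slice(tensor: torch.Tensor, cu_sizes: torch.Tensor,
                          local_batch_size: int, rank: int) -> torch.Tensor:
    # Same end-minus-length arithmetic as the helper in the diff, assuming
    # `cu_sizes` holds cumulative end offsets: the slice for `rank` ends at
    # its offset and spans `local_batch_size` rows.
    end_idx = int(cu_sizes[rank])
    start_idx = end_idx - local_batch_size
    return tensor[start_idx:end_idx]


rank = 2
local = get_local_batch_slice(complete_output, cu_tokens,
                              int(per_rank_tokens[rank]), rank)
assert local.shape == (int(per_rank_tokens[rank]), hidden_size)  # rows 5..8 for rank 2
```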