@@ -83,7 +83,9 @@ def load_balancing_loss_func(
 
     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )
 
     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
@@ -99,20 +101,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )
 
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -123,9 +129,9 @@ def load_balancing_loss_func(
         )
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
 
     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
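
For readers following the math: the value returned above is the Switch-Transformers-style auxiliary loss, i.e. the number of experts times the dot product between the fraction of tokens dispatched to each expert and the mean router probability for that expert. A standalone sketch of the unmasked branch (illustrative only, not part of this diff; sizes and values are made up):

```python
import torch

# Hypothetical sizes for illustration: 4 layers x 16 tokens, 8 experts, top-2 routing
num_experts, top_k = 8, 2
gate_logits = torch.randn(4 * 16, num_experts)

routing_weights = torch.softmax(gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

tokens_per_expert = expert_mask.float().mean(dim=0)    # fraction of tokens routed to each expert
router_prob_per_expert = routing_weights.mean(dim=0)   # mean router probability per expert
aux_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
```
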
@@ -144,7 +150,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states
 
@@ -162,25 +170,38 @@ def __init__(self, config: MixtralConfig):
             self.append(MixtralMLP(config))
 
     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
+            hidden_states (`torch.Tensor`):
+                Input tensor.
+            top_k_index (`torch.Tensor`):
+                The indices of the top-k selected experts.
+            top_k_weights (`torch.Tensor`):
+                The weights corresponding to the top-k selected experts.
         """
+
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            top_k_index, num_classes=self.num_experts
+        ).permute(2, 1, 0)
 
         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_state = hidden_states[None, top_x].reshape(
+                -1, hidden_states.shape[-1]
+            )
+            current_hidden_states = (
+                self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
+            )
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
         return final_hidden_states
 
 
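
As a side note on the dispatch above: after the permute, the one-hot mask has shape (num_experts, top_k, num_tokens), so torch.where(expert_mask[expert_idx]) yields, for a single expert, the top-k slot index (used to pick the routing weight) and the token index (used to gather rows of hidden_states). A toy illustration of that indexing (hypothetical values, not part of this diff):

```python
import torch

top_k_index = torch.tensor([[0, 2], [1, 0], [0, 1]])  # 3 tokens, top-2 expert choices each
num_experts = 4

# Same construction as the forward above: (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

# Which (slot, token) pairs were routed to expert 0?
idx, top_x = torch.where(expert_mask[0])
print(idx.tolist(), top_x.tolist())  # [0, 0, 1] [0, 2, 1] -> tokens 0 and 2 via slot 0, token 1 via slot 1
```
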
@@ -201,11 +222,15 @@ def route_tokens_to_experts(self, router_logits):
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         router_logits = self.gate(hidden_states)
         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
-        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
+        hidden_states = self.experts(
+            hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)
+        )
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
 
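
route_tokens_to_experts itself is not touched by this hunk; in Mixtral it boils down to a softmax over the router logits, a top-k selection, and a renormalization of the kept weights. Roughly, as a sketch under that assumption (not the exact implementation):

```python
import torch

def route_tokens_to_experts(router_logits: torch.Tensor, top_k: int):
    # Softmax over experts, keep the top-k per token, renormalize the kept weights
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
    top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)
    top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
    return top_k_index, top_k_weights

# e.g. 5 tokens routed over 8 experts, keeping 2 experts per token
top_k_index, top_k_weights = route_tokens_to_experts(torch.randn(5, 8), top_k=2)
```
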
@@ -230,8 +255,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)
 
         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     def forward(
         self,
@@ -265,7 +294,9 @@ def forward(
 class MixtralPreTrainedModel(MistralPreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
+        "router_logits": OutputRecorder(
+            nn.Linear, layer_name="block_sparse_moe.gate", index=0
+        ),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
@@ -284,7 +315,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -293,14 +326,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -381,7 +422,9 @@ def forward(
         ```"""
 
         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -399,7 +442,11 @@ def forward(
 
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])
 
         loss = None
@@ -415,7 +462,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device
 
         return MoeCausalLMOutputWithPast(
             loss=loss,