@@ -83,7 +83,9 @@ def load_balancing_loss_func(

     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )

     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

@@ -99,20 +101,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )

         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )

         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)

         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -123,9 +129,9 @@ def load_balancing_loss_func(
         )

         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)

     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
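
For reference, here is a minimal sketch of the auxiliary load-balancing loss this function computes, using toy tensors and the no-attention-mask branch. The sizes (`num_layers`, `num_tokens`, etc.) are illustrative, not taken from the diff:

```python
import torch

# Toy setup: 2 layers, 4 tokens each, 8 experts, top-2 routing (illustrative values).
num_layers, num_tokens, num_experts, top_k = 2, 4, 8, 2
gate_logits = tuple(torch.randn(num_tokens, num_experts) for _ in range(num_layers))

# Concatenate per-layer router logits, as the diffed function does.
concatenated = torch.cat(list(gate_logits), dim=0)
routing_weights = torch.nn.functional.softmax(concatenated, dim=-1)

# Top-k selection and a one-hot mask of which experts each token hit.
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

# Fraction of tokens dispatched to each expert vs. mean router probability per expert.
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
router_prob_per_expert = torch.mean(routing_weights, dim=0)

aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts
print(aux_loss)
```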
@@ -144,7 +150,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states

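As context for the `w1`/`w3`/`w2` pattern above, a standalone sketch of the same gated (SwiGLU-style) MLP in plain PyTorch; the dimensions and the SiLU activation are assumptions for illustration:

```python
import torch
from torch import nn


class GatedMLP(nn.Module):
    """Minimal SwiGLU-style block: act(w1(x)) gates w3(x), then w2 projects back."""

    def __init__(self, hidden_size: int = 16, intermediate_size: int = 32):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states))


x = torch.randn(2, 5, 16)   # (batch, seq_len, hidden)
print(GatedMLP()(x).shape)  # torch.Size([2, 5, 16])
```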
@@ -161,26 +169,34 @@ def __init__(self, config: MixtralConfig):
         for _ in range(self.num_experts):
             self.append(MixtralMLP(config))

-    def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
+            hidden_states (`torch.Tensor`):
+                Input tensor.
+            top_k_index (`torch.Tensor`):
+                The indices of the top-k selected experts.
+            top_k_weights (`torch.Tensor`):
+                The weights corresponding to the top-k selected experts.
         """
+
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            top_k_index, num_classes=self.num_experts
+        ).permute(2, 1, 0)

         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_state = hidden_states[None, top_x].reshape(
+                -1, hidden_states.shape[-1]
+            )
+            current_hidden_states = (
+                self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
+            )
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
         return final_hidden_states


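The dispatch logic above (one-hot expert mask, `torch.where`, `index_add_`) is hard to read in diff form; here is a self-contained sketch of the same gather/scatter pattern with toy experts, where an `nn.Linear` stands in for each `MixtralMLP`:

```python
import torch
from torch import nn

num_tokens, hidden_dim, num_experts, top_k = 6, 4, 3, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))

# Pretend router output: indices and weights of the top-k experts per token.
top_k_weights = torch.rand(num_tokens, top_k)
top_k_index = torch.randint(0, num_experts, (num_tokens, top_k))

final_hidden_states = torch.zeros_like(hidden_states)
# (num_experts, top_k, num_tokens): row e marks which tokens picked expert e and in which slot.
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    # idx: which top-k slot, top_x: which token rows routed to this expert.
    idx, top_x = torch.where(expert_mask[expert_idx])
    if top_x.numel() == 0:
        continue
    current_state = hidden_states[top_x]
    weighted = experts[expert_idx](current_state) * top_k_weights[top_x, idx, None]
    # Scatter-add each expert's weighted output back into the owning token rows.
    final_hidden_states.index_add_(0, top_x, weighted)

print(final_hidden_states.shape)  # torch.Size([6, 4])
```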
@@ -201,11 +217,15 @@ def route_tokens_to_experts(self, router_logits):
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         router_logits = self.gate(hidden_states)
         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
-        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
+        hidden_states = self.experts(
+            hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)
+        )
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states

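`route_tokens_to_experts` itself is not part of this hunk. For orientation only, a typical top-k softmax router looks roughly like the following hedged sketch (not necessarily the exact Mixtral implementation):

```python
import torch


def route_tokens_to_experts(router_logits: torch.Tensor, top_k: int = 2):
    """Sketch of a standard top-k router: softmax, keep top-k, renormalize the kept weights."""
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
    top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)
    top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
    return top_k_index, top_k_weights


router_logits = torch.randn(10, 8)             # (num_tokens, num_experts)
top_k_index, top_k_weights = route_tokens_to_experts(router_logits)
print(top_k_index.shape, top_k_weights.shape)  # torch.Size([10, 2]) torch.Size([10, 2])
```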
@@ -230,8 +250,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)

         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,
@@ -265,7 +289,9 @@ def forward(
 class MixtralPreTrainedModel(MistralPreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
+        "router_logits": OutputRecorder(
+            nn.Linear, layer_name="block_sparse_moe.gate", index=0
+        ),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
@@ -284,7 +310,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )

         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -293,14 +321,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
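The `mask_function` selection above chooses between a plain causal mask and a sliding-window causal mask. A plain-torch sketch of the difference, independent of the `create_*_mask` helpers used here (the exact window convention is an assumption):

```python
import torch


def sliding_window_causal_mask(seq_len: int, window: int | None) -> torch.Tensor:
    """True where attention is allowed: causal, optionally limited to the last `window` positions."""
    q = torch.arange(seq_len)[:, None]
    k = torch.arange(seq_len)[None, :]
    allowed = k <= q                   # causal part
    if window is not None:
        allowed &= (q - k) < window    # sliding-window part
    return allowed


print(sliding_window_causal_mask(5, None).int())  # full lower-triangular mask
print(sliding_window_causal_mask(5, 2).int())     # banded lower-triangular mask
```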
@@ -381,7 +417,9 @@ def forward(
         ```"""

         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -399,7 +437,11 @@ def forward(

         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
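The `logits_to_keep` handling above accepts either an int (keep only the last N positions) or an explicit index tensor; a quick illustration of the slicing it produces, with made-up shapes:

```python
import torch

hidden_states = torch.randn(2, 7, 16)  # (batch, seq_len, hidden)

logits_to_keep = 1                     # int: keep only the last position (typical during generation)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

logits_to_keep = torch.tensor([0, 3, 6])  # tensor: keep explicitly listed positions
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 3, 16])
```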
@@ -415,7 +457,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device

         return MoeCausalLMOutputWithPast(
             loss=loss,
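The auxiliary router loss is only folded into the training loss when labels are provided, scaled by `router_aux_loss_coef`. Conceptually, with toy values and an illustrative coefficient:

```python
import torch

lm_loss = torch.tensor(2.31)   # main next-token loss (toy value)
aux_loss = torch.tensor(0.12)  # load-balancing loss from the router logits (toy value)
router_aux_loss_coef = 0.02    # illustrative coefficient, normally read from the model config

# Move the aux loss to the same device as the main loss before adding, as the diff comment notes.
total_loss = lm_loss + router_aux_loss_coef * aux_loss.to(lm_loss.device)
print(total_loss)              # tensor(2.3124)
```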