Commit 70b756c

Fix: Mixtral docstring inconsistency for top_k_index/weights
1 parent dd8f231 · commit 70b756c

src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 78 additions & 34 deletions
@@ -83,7 +83,9 @@ def load_balancing_loss_func(
 
     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )
 
     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
@@ -99,20 +101,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )
 
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -123,9 +129,9 @@ def load_balancing_loss_func(
         )
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
 
     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
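
The three hunks above are pure re-wraps of `load_balancing_loss_func`; the math is unchanged. For readers skimming the diff, the quantity being computed is the Switch-Transformers-style auxiliary loss: the fraction of tokens routed to each expert times the mean router probability for that expert, summed and scaled by `num_experts`. A minimal sketch of the unmasked branch, assuming a single `(num_tokens, num_experts)` logits tensor (names here are illustrative, not the library API):

import torch
import torch.nn.functional as F

def toy_load_balancing_loss(gate_logits: torch.Tensor, top_k: int, num_experts: int) -> torch.Tensor:
    # gate_logits: (num_tokens, num_experts); no padding mask in this sketch
    routing_weights = F.softmax(gate_logits, dim=-1)
    _, selected = torch.topk(routing_weights, top_k, dim=-1)        # (num_tokens, top_k)
    expert_mask = F.one_hot(selected, num_experts)                  # (num_tokens, top_k, num_experts)
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)      # fraction of tokens routed to each expert
    router_prob_per_expert = torch.mean(routing_weights, dim=0)     # mean router probability per expert
    return torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts

loss = toy_load_balancing_loss(torch.randn(16, 8), top_k=2, num_experts=8)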
@@ -144,7 +150,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states
 
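
The `MixtralMLP.forward` being re-wrapped here is the usual gated feed-forward: the activation of `w1` gates the `w3` projection, then `w2` projects back down to the hidden size. A self-contained sketch with illustrative sizes and SiLU as the assumed activation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGatedMLP(nn.Module):
    # Mirrors the w1/w2/w3 structure used above; the dimensions are illustrative
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 128):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(hidden_states)) * self.w3(hidden_states))

out = ToyGatedMLP()(torch.randn(4, 64))  # shape (4, 64)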
@@ -161,26 +169,34 @@ def __init__(self, config: MixtralConfig):
         for _ in range(self.num_experts):
             self.append(MixtralMLP(config))
 
-    def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
+            hidden_states (`torch.Tensor`):
+                Input tensor.
+            top_k_index (`torch.Tensor`):
+                The indices of the top-k selected experts.
+            top_k_weights (`torch.Tensor`):
+                The weights corresponding to the top-k selected experts.
         """
+
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            top_k_index, num_classes=self.num_experts
+        ).permute(2, 1, 0)
 
         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_state = hidden_states[None, top_x].reshape(
+                -1, hidden_states.shape[-1]
+            )
+            current_hidden_states = (
+                self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
+            )
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
         return final_hidden_states
 
 
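This hunk also carries the actual fix named in the commit message: the docstring now documents `top_k_index` and `top_k_weights` instead of the stale `selected_experts` / `routing_weights` names. As a side note on the surrounding code, the one-hot/permute/`torch.where` bookkeeping recovers, for each expert, which token positions selected it and in which top-k slot. A toy sketch (hypothetical shapes, not the module itself):

import torch

num_tokens, top_k, num_experts = 5, 2, 4
top_k_index = torch.tensor([[0, 2], [1, 2], [0, 3], [2, 1], [0, 2]])  # (num_tokens, top_k)

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    # idx: which top-k slot picked this expert, top_x: which token positions picked it
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(expert_idx, top_x.tolist(), idx.tolist())
# e.g. expert 0 is hit by tokens 0, 2, 4 in slot 0; expert 2 by token 3 in slot 0 and tokens 0, 1, 4 in slot 1
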
@@ -201,11 +217,15 @@ def route_tokens_to_experts(self, router_logits):
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         router_logits = self.gate(hidden_states)
         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
-        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
+        hidden_states = self.experts(
+            hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)
+        )
        hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
 
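`route_tokens_to_experts` itself is not touched by this diff; it is the step that turns `router_logits` into the `top_k_index` / `top_k_weights` pair passed to `self.experts`. A hedged sketch of a typical top-k softmax router (an approximation for illustration, not necessarily the exact Mixtral implementation):

import torch
import torch.nn.functional as F

def route_tokens_to_experts_sketch(router_logits: torch.Tensor, top_k: int = 2):
    # router_logits: (num_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)
    top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)  # renormalize over the selected experts
    return top_k_index, top_k_weights

idx, w = route_tokens_to_experts_sketch(torch.randn(10, 8))  # (10, 2) indices and weights
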
@@ -230,8 +250,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)
 
         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     def forward(
         self,
@@ -265,7 +289,9 @@ def forward(
 class MixtralPreTrainedModel(MistralPreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
+        "router_logits": OutputRecorder(
+            nn.Linear, layer_name="block_sparse_moe.gate", index=0
+        ),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
@@ -284,7 +310,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -293,14 +321,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
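
Both re-wrapped statements above keep the same logic: `cache_position` enumerates the positions of the new tokens, offset by however many tokens the cache has already seen, and the mask builder is chosen based on whether sliding-window attention is configured. The position arithmetic in isolation, with made-up numbers:

import torch

past_seen_tokens = 6   # e.g. tokens already stored in past_key_values
new_tokens = 3         # inputs_embeds.shape[1] for the current step
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)   # add the batch dimension
print(cache_position.tolist())  # [6, 7, 8]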
@@ -381,7 +417,9 @@ def forward(
         ```"""
 
         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -399,7 +437,11 @@ def forward(
 
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])
 
         loss = None
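
For context on the re-wrapped conditional: `logits_to_keep` may be an int (keep only the last N positions, the common generation case) or an explicit index tensor, and `slice_indices` is built accordingly. A quick sketch with made-up shapes:

import torch

hidden_states = torch.randn(2, 10, 16)  # (batch, seq_len, hidden)

logits_to_keep = 1  # int: keep only the last position
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

logits_to_keep = torch.tensor([0, 4, 9])  # tensor: keep explicit positions
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 3, 16])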
@@ -415,7 +457,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device
 
         return MoeCausalLMOutputWithPast(
             loss=loss,
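The final hunk only re-wraps the line that folds the router auxiliary loss into the language-modeling loss; the coefficient scaling and the `.to(loss.device)` move are unchanged. In isolation, with illustrative values:

import torch

loss = torch.tensor(2.30)       # main language-modeling loss
aux_loss = torch.tensor(0.12)   # load-balancing loss computed from the router logits
router_aux_loss_coef = 0.02     # illustrative value; the real coefficient comes from the config

loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)  # keep both terms on the same device
print(loss.item())  # ≈ 2.3024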