Commit 8b06044

Fix: Mixtral docstring inconsistency for top_k_index/weights
Parent commit: dd8f231

File tree: 1 file changed (+81, -32 lines)

src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 81 additions & 32 deletions
@@ -83,7 +83,9 @@ def load_balancing_loss_func(

     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )

     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

@@ -99,20 +101,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )

         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )

         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)

         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -123,9 +129,9 @@ def load_balancing_loss_func(
         )

         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)

     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
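
For context on the reflowed lines above: this function computes the auxiliary load-balancing loss, i.e. the per-expert fraction of routed tokens multiplied by the per-expert mean routing probability, summed and scaled by num_experts. A minimal sketch of the unmasked path (no attention_mask), using made-up shapes and a plain top-k in place of the model's recorded routing, might look like:

import torch

# Hypothetical sizes, for illustration only.
num_layers, num_tokens, num_experts, top_k = 2, 8, 4, 2

# Stand-in for the per-layer gate logits the model records (a tuple, as in the diff above).
gate_logits = tuple(torch.randn(num_tokens, num_experts) for _ in range(num_layers))

concatenated_gate_logits = torch.cat(list(gate_logits), dim=0)
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

# Fraction of tokens routed to each expert, and mean routing probability per expert.
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)       # (top_k, num_experts)
router_prob_per_expert = torch.mean(routing_weights, dim=0)      # (num_experts,)

aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts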
@@ -144,7 +150,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states

@@ -162,25 +170,38 @@ def __init__(self, config: MixtralConfig):
             self.append(MixtralMLP(config))

     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
+            hidden_states (`torch.Tensor`):
+                Input tensor.
+            top_k_index (`torch.Tensor`):
+                The indices of the top-k selected experts.
+            top_k_weights (`torch.Tensor`):
+                The weights corresponding to the top-k selected experts.
         """
+
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            top_k_index, num_classes=self.num_experts
+        ).permute(2, 1, 0)

         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_state = hidden_states[None, top_x].reshape(
+                -1, hidden_states.shape[-1]
+            )
+            current_hidden_states = (
+                self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
+            )
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
         return final_hidden_states


@@ -201,11 +222,15 @@ def route_tokens_to_experts(self, router_logits):
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         router_logits = self.gate(hidden_states)
         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
-        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
+        hidden_states = self.experts(
+            hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)
+        )
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
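
For the docstring fix this commit targets: top_k_index and top_k_weights are the two tensors produced by route_tokens_to_experts and consumed by MixtralExperts.forward after tokens are flattened. A rough shape sketch, using made-up dimensions and a simplified softmax-plus-top-k stand-in for the actual routing (which may also renormalize the selected weights), would be:

import torch

# Hypothetical dimensions, for illustration only.
batch_size, sequence_length, hidden_dim = 2, 5, 32
num_experts, top_k = 8, 2

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
flat_states = hidden_states.view(-1, hidden_dim)                # (batch * seq, hidden_dim)

router_logits = torch.randn(flat_states.shape[0], num_experts)  # stand-in for self.gate(flat_states)
routing_weights = torch.softmax(router_logits, dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)

# top_k_index:   (batch * seq, top_k) integer indices of the selected experts
# top_k_weights: (batch * seq, top_k) weights for the selected experts
# The experts module returns a tensor of shape (batch * seq, hidden_dim),
# which the MoE block reshapes back to (batch, seq, hidden_dim).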

@@ -230,8 +255,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)

         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,
@@ -265,7 +294,9 @@ def forward(
 class MixtralPreTrainedModel(MistralPreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
+        "router_logits": OutputRecorder(
+            nn.Linear, layer_name="block_sparse_moe.gate", index=0
+        ),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
@@ -284,7 +315,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )

         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -293,14 +326,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -381,7 +422,9 @@ def forward(
         ```"""

         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -399,7 +442,11 @@ def forward(

         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
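
The slice_indices expression reflowed above only projects the trailing positions through lm_head when logits_to_keep is an int. A small worked example with a hypothetical value:

logits_to_keep = 1  # hypothetical: keep only the last position, as in incremental decoding
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
# hidden_states[:, slice_indices, :] is then hidden_states[:, -1:, :], i.e. just the final token;
# a tensor of indices can be passed instead, in which case it is used for indexing as-is.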
@@ -415,7 +462,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device

         return MoeCausalLMOutputWithPast(
             loss=loss,
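
For how the auxiliary loss above feeds into training: it is scaled by the configured coefficient and added to the language-modeling loss. A conceptual sketch with stand-in values (not taken from the model's config):

import torch

router_aux_loss_coef = 0.001          # hypothetical coefficient, normally read from the config
lm_loss = torch.tensor(2.31)          # stand-in for the cross-entropy loss over labels
aux_loss = torch.tensor(0.98)         # stand-in for load_balancing_loss_func(...)

total_loss = lm_loss + router_aux_loss_coef * aux_loss.to(lm_loss.device)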
