113 changes: 81 additions & 32 deletions src/transformers/models/mixtral/modular_mixtral.py
@@ -83,7 +83,9 @@ def load_balancing_loss_func(

     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )

     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

@@ -99,20 +101,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )

         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )

         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)

         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -123,9 +129,9 @@ def load_balancing_loss_func(
         )

         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)

     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
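
For reference, a minimal standalone sketch (not part of this diff, toy shapes assumed) of what the unmasked branch of `load_balancing_loss_func` computes: the fraction of tokens routed to each expert times the mean router probability for that expert, summed over experts and scaled by `num_experts`.

```python
import torch

num_experts, top_k = 4, 2
# Hypothetical concatenated gate logits with shape (num_layers * batch * seq_len, num_experts)
gate_logits = torch.randn(6, num_experts)

routing_weights = torch.softmax(gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

tokens_per_expert = torch.mean(expert_mask.float(), dim=0)      # (top_k, num_experts)
router_prob_per_expert = torch.mean(routing_weights, dim=0)     # (num_experts,)

aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts
```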
@@ -144,7 +150,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states

@@ -162,25 +170,38 @@ def __init__(self, config: MixtralConfig):
             self.append(MixtralMLP(config))

     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
+            hidden_states (`torch.Tensor`):
+                Input tensor.
+            top_k_index (`torch.Tensor`):
+                The indices of the top-k selected experts.
+            top_k_weights (`torch.Tensor`):
+                The weights corresponding to the top-k selected experts.
         """

         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            top_k_index, num_classes=self.num_experts
+        ).permute(2, 1, 0)

         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_state = hidden_states[None, top_x].reshape(
+                -1, hidden_states.shape[-1]
+            )
+            current_hidden_states = (
+                self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
+            )
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
         return final_hidden_states
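
As a side note on the dispatch pattern above, here is a self-contained sketch (toy shapes, with a stand-in for the expert MLP) of how the permuted one-hot mask and `index_add_` scatter the weighted expert outputs back to their token positions.

```python
import torch

tokens, hidden_dim, num_experts, top_k = 5, 8, 3, 2
hidden_states = torch.randn(tokens, hidden_dim)               # flattened (batch * seq_len, hidden_dim)
top_k_index = torch.randint(0, num_experts, (tokens, top_k))  # hypothetical routing indices
top_k_weights = torch.rand(tokens, top_k)                     # hypothetical routing weights

final_hidden_states = torch.zeros_like(hidden_states)
# (tokens, top_k, num_experts) -> (num_experts, top_k, tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    # idx: which top-k slot chose this expert, top_x: which token rows hit it
    idx, top_x = torch.where(expert_mask[expert_idx])
    if top_x.numel() == 0:
        continue
    expert_out = hidden_states[top_x] * 2.0                   # stand-in for self[expert_idx](...)
    final_hidden_states.index_add_(0, top_x, expert_out * top_k_weights[top_x, idx, None])
```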


@@ -201,11 +222,15 @@ def route_tokens_to_experts(self, router_logits):
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         router_logits = self.gate(hidden_states)
         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
-        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
+        hidden_states = self.experts(
+            hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)
+        )
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
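
And a toy end-to-end sketch of this forward pass (hypothetical shapes; `route_tokens_to_experts` is not shown in this hunk, so the usual Mixtral routing of softmax, top-k, then renormalize is assumed here):

```python
import torch
import torch.nn as nn

batch_size, sequence_length, hidden_dim = 2, 4, 8
num_experts, top_k = 4, 2
gate = nn.Linear(hidden_dim, num_experts, bias=False)

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
flat = hidden_states.view(-1, hidden_dim)                # (batch * seq_len, hidden_dim)

router_logits = gate(flat)                               # (batch * seq_len, num_experts)
routing_weights = torch.softmax(router_logits, dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

# `self.experts(flat, top_k_index, top_k_weights.to(flat.dtype))` would return
# (batch * seq_len, hidden_dim), which is reshaped back to (batch, seq_len, hidden_dim).
```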

@@ -230,8 +255,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)

         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,
@@ -265,7 +294,9 @@ def forward(
 class MixtralPreTrainedModel(MistralPreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
+        "router_logits": OutputRecorder(
+            nn.Linear, layer_name="block_sparse_moe.gate", index=0
+        ),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
@@ -284,7 +315,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )

         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -293,14 +326,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -381,7 +422,9 @@ def forward(
         ```"""

         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -399,7 +442,11 @@ def forward(

         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
@@ -415,7 +462,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device

         return MoeCausalLMOutputWithPast(
             loss=loss,