Add llama 4 scaling #42045
base: main
Conversation
```diff
  rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+ llama_4_scaling: Optional[LLama4Scaling] = None,
```
Can we have it in rope_parameters instead? It looks more intuitive to keep all RoPE-related args in one place. Since the dict already contains parameters for RoPE scaling, we can add these as llama4_{param_name}.
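For illustration, a minimal sketch of what the merged dict could look like if the scaling argument is folded into rope_parameters as suggested; the key name `llama_4_scaling_beta` matches the diff further down, while the other keys and all values are assumptions, not the final API:

```python
# Sketch only: keep all RoPE-related args in a single dict (values illustrative).
rope_parameters = {
    "rope_type": "default",
    "rope_theta": 1_000_000.0,      # assumed base frequency
    "llama_4_scaling_beta": 0.1,    # new Llama-4-style scaling knob (assumed value)
}
```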
```diff
  vocab_size=32000,
  head_dim=64,
- hidden_act="gelu",
+ hidden_act="silu",
```
This might be a bit breaking for configs that were saved without hidden_act and therefore previously defaulted to GeLU.
I think it was actually broken before; this is a fix here.
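To illustrate the concern, a minimal sketch (plain dicts, values taken from the diff above) of how a checkpoint serialized without `hidden_act` silently changes behavior when the class default flips:

```python
# A config dict saved before this change, with no explicit "hidden_act" key:
saved_config = {"vocab_size": 32000, "head_dim": 64}

old_resolved = saved_config.get("hidden_act", "gelu")  # old class default
new_resolved = saved_config.get("hidden_act", "silu")  # default after this PR
print(old_resolved, "->", new_resolved)                # gelu -> silu
```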
```diff
  image_processor=image_processor,
  image_token="[IMG]",
  patch_size=patch_size,
- chat_template=chat_template,
```
model has no chat template anymore? 🥲
It makes no sense to put one here anymore, unfortunately: depending on the model, the chat template looks very different (tokenizer version, thinking or not, ...). So having a default one is arguably worse than none at all, imo.
```diff
  cos, sin = position_embeddings
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+ if self.config.rope_parameters.llama_4_scaling_beta is not None:
```
rope_parameters is a simple dict, so we'd better safe-get it: rope_parameters.get("llama_4_scaling_beta").
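A minimal sketch of the suggested guard, using a plain example dict and the key name from the diff above:

```python
# Prefer dict-style safe access over attribute access on rope_parameters:
rope_parameters = {"rope_type": "default"}            # example dict without the new key

beta = rope_parameters.get("llama_4_scaling_beta")    # returns None instead of raising
if beta is not None:
    ...  # apply the Llama-4-style query scaling only when configured
```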
Looks good, let's just address one comment and fix CI so it's ✅

[For maintainers] Suggested jobs to run (before merge): run-slow: mistral, mistral3
LGTM! We will need to add a new model for this change to get through, as we never change old models to add new features 🤗
The best way is to write a simple modular_new_model.py with something like this:
```python
...

class MyModelAttention(MistralAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states = query_states * self._get_llama4_attn_scale(cache_position).to(query_states.dtype)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
```

The best would be to modify the RopeEmbedding to do the scaling in there; this way you don't have to change the forward pass!
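For context, a hedged sketch of what a helper like `_get_llama4_attn_scale` could compute, modeled on Llama 4's attention temperature tuning; the parameter names (`floor_scale`, `attn_scale`), the example values, and the broadcasting shape are assumptions for illustration, not necessarily what this PR implements:

```python
import torch

def llama4_attn_scale(cache_position: torch.LongTensor, floor_scale: float, attn_scale: float) -> torch.Tensor:
    # Log-growing per-position factor, as in Llama 4's attention temperature tuning
    # (assumed here; the PR's exact formula/beta parameterization may differ).
    positions = cache_position.float() + 1.0
    scale = torch.log(torch.floor(positions / floor_scale) + 1.0) * attn_scale + 1.0
    # Reshape so it broadcasts against query_states of shape (batch, heads, seq, head_dim).
    return scale.view(1, 1, -1, 1)

# Hypothetical usage inside the attention forward:
# query_states = query_states * llama4_attn_scale(cache_position, 8192.0, 0.1).to(query_states.dtype)
```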
What does this PR do?
Add llama 4 scaling for long context to Mistral models.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@patrickvonplaten @ArthurZucker @zucchini-nlp