
Conversation

@juliendenize
Contributor

What does this PR do?

Add Llama 4 scaling for long context to Mistral models.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten @ArthurZucker @zucchini-nlp

Comment on lines 156 to 157
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
llama_4_scaling: Optional[LLama4Scaling] = None,
Member

Can we have it in rope_parameters instead? It looks more intuitive to keep all RoPE-related args in one place. Since the dict already contains parameters for RoPE scaling, we can add these as llama4_{param_name}.
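
A minimal sketch of what that could look like, assuming the Llama 4 knobs simply move into the dict under the suggested llama4_{param_name} prefix (the key names and values below are illustrative assumptions, not the final API):

from transformers import MistralConfig

# Illustrative only: fold the Llama 4 long-context scaling knobs into rope_parameters,
# prefixed as suggested, instead of passing a separate llama_4_scaling argument.
rope_parameters = {
    "rope_type": "default",
    "rope_theta": 1_000_000.0,
    "llama4_scaling_beta": 0.1,    # assumed name/value
    "llama4_scaling_floor": 8192,  # assumed name/value
}

# Hypothetical usage on this PR's branch, where the config accepts rope_parameters.
config = MistralConfig(rope_parameters=rope_parameters)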

vocab_size=32000,
head_dim=64,
hidden_act="gelu",
hidden_act="silu",
Member

Might be a bit breaking for configs that were saved without hidden_act and therefore previously defaulted to GeLU.

Contributor Author

I think it was actually broken before; this is a fix.
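
For context, a minimal generic sketch of the mechanism behind the concern (this is not the transformers serialization code): values equal to a default are typically not written out when a config is saved, so old checkpoints inherit whatever default the current code ships with.

import json

class TinyConfig:  # generic stand-in, not the actual MistralConfig
    def __init__(self, hidden_act="silu"):  # default changed from "gelu" to "silu"
        self.hidden_act = hidden_act

old_checkpoint = "{}"  # serialized back when the default was "gelu", so the key was omitted
loaded = TinyConfig(**json.loads(old_checkpoint))
print(loaded.hidden_act)  # "silu": an old config silently picks up the new default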

image_processor=image_processor,
image_token="[IMG]",
patch_size=patch_size,
chat_template=chat_template,
Member

model has no chat template anymore? 🥲

Contributor Author
@juliendenize Nov 7, 2025

Unfortunately it no longer makes sense to put one here: depending on the model, the chat template looks very different (tokenizer version, thinking or not, ...). So having a default one is arguably worse than none at all, imo.

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if self.config.rope_parameters.llama_4_scaling_beta is not None:
Member

rope_parameters is a simple dict, so we'd better safe-get: rope_parameters.get('llama_4_scaling_beta')
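
A standalone sketch of the suggested safe access (the key name follows the attribute used in the diff above):

rope_parameters = {"rope_type": "default", "rope_theta": 1_000_000.0}  # no Llama 4 keys saved

# dict.get() returns None instead of raising KeyError when the key was never added.
beta = rope_parameters.get("llama_4_scaling_beta")
if beta is not None:
    print(f"apply Llama 4 scaling with beta={beta}")
else:
    print("plain RoPE path")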

@zucchini-nlp
Member

Looks good, let's just address one comment and fix CI so it's ✅

@github-actions
Contributor

github-actions bot commented Nov 7, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: mistral, mistral3

Collaborator
@ArthurZucker left a comment

LGTM! We will need to add a new model for this change to get through, as we never change old models to add new features 🤗

The best way is to write a simple modular_new_model.py with something like this:

...

class MyModelAttention(MistralAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states = query_states * self._get_llama4_attn_scale(cache_position).to(query_states.dtype)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

The best would be to modify the RopeEmbedding to do the scaling in there; this way you don't have to change the forward pass!
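
For reference, a minimal sketch of what the _get_llama4_attn_scale helper used above might compute, modeled on Llama 4's attention temperature tuning; the floor_scale/attn_scale names, defaults, and exact formula are assumptions borrowed from the Llama 4 implementation, not this PR's final code:

import torch

def get_llama4_attn_scale(
    cache_position: torch.LongTensor,
    floor_scale: float = 8192.0,  # assumed default, as in Llama 4
    attn_scale: float = 0.1,      # assumed default, as in Llama 4
) -> torch.Tensor:
    # Per-position multiplier that grows logarithmically with the absolute position,
    # which keeps attention logits from flattening out at very long context lengths.
    scales = torch.log(torch.floor((cache_position.float() + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
    # Shape (1, 1, seq_len, 1) so it broadcasts over (batch, heads, seq_len, head_dim) queries.
    return scales.view(1, 1, -1, 1)

# Usage sketch, mirroring the attention forward above:
# query_states = query_states * get_llama4_attn_scale(cache_position).to(query_states.dtype)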
