125 changes: 61 additions & 64 deletions src/transformers/modeling_rope_utils.py
@@ -45,17 +45,19 @@ def dynamic_rope_update(rope_forward):
def longrope_frequency_update(self, position_ids, device, layer_type=None):
"""Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
seq_len = torch.max(position_ids) + 1
original_max_position_embeddings = getattr(
self.config, "original_max_position_embeddings", self.config.max_position_embeddings
)

if layer_type is None:
rope_type = self.rope_type
original_inv_freq = self.original_inv_freq
prefix = ""
original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"]
@zucchini-nlp (Member Author) commented on Dec 2, 2025:

Safe to assume it already exists: we move `original_max_position_embeddings` to its correct location at config init time.
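To make that assumption concrete, here is a minimal sketch (plain dicts with made-up values, not the PR's code) of what standardization is expected to guarantee before `longrope_frequency_update` runs:

```python
# Hypothetical values; mirrors what `standardize_rope_params()` is expected to guarantee
# at config init, so the lookup above never needs a fallback.
rope_parameters = {"rope_type": "longrope", "rope_theta": 10000.0}
max_position_embeddings = 131072            # post-extension context length
original_max_position_embeddings = 4096     # pretraining-time context length

# Standardization moves/sets the key inside the rope dict (setdefault-style).
rope_parameters.setdefault("original_max_position_embeddings", original_max_position_embeddings)

# After init, indexing directly is safe, and the long factor is used past the pretraining length.
seq_len = 8192
use_long_factor = seq_len > rope_parameters["original_max_position_embeddings"]
print(use_long_factor)  # True
```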

else:
rope_type = self.rope_type[layer_type]
original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
prefix = f"{layer_type}_"
original_max_position_embeddings = self.config.rope_parameters[layer_type][
"original_max_position_embeddings"
]

if seq_len > original_max_position_embeddings:
if not hasattr(self, f"{layer_type}_long_inv_freq"):
@@ -222,7 +224,6 @@ def _compute_dynamic_ntk_parameters(
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
# For backward compatibility standardize the `rope_parameters_dict` if it uses old format
config.standardize_rope_params()
rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -231,23 +232,22 @@
partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = rope_parameters_dict["factor"]
attention_factor = 1.0 # Unused in this type of RoPE

# seq_len: default to max_position_embeddings, e.g. at init time
if seq_len is None:
seq_len = max_position_embeddings
seq_len = config.max_position_embeddings
elif isinstance(seq_len, torch.Tensor):
seq_len = torch.maximum(
seq_len,
torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
)
else:
seq_len = max(seq_len, max_position_embeddings)
seq_len = max(seq_len, config.max_position_embeddings)

# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
return inv_freq, attention_factor

@@ -291,8 +291,7 @@ def _compute_yarn_parameters(
`mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing
the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor`
will be calculated based on `factor` only.
* `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used
during pretraining. If not provided, the function falls back to `max_position_embeddings`.
* `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining.
* `truncate` (`bool`, *optional*): Whether to truncate the correction range.

Additionally, this function will make use of the following properties if they are found in the config:
@@ -323,15 +322,13 @@
attention_factor = rope_parameters_dict.get("attention_factor")
mscale = rope_parameters_dict.get("mscale")
mscale_all_dim = rope_parameters_dict.get("mscale_all_dim")
original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]

# NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if "original_max_position_embeddings" in rope_parameters_dict:
original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
# NOTE: DeepSeek-V3 (and potentially other models) have an `original_max_position_embeddings` field
# containing the pretrained value. They use the ratio between `max_position_embeddings` and this value
# to compute the default attention scaling factor, instead of using `factor`.
if factor is None:
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings

def get_mscale(scale, mscale=1):
if scale <= 1:
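For a concrete sense of the fallback above, here is a small sketch with made-up numbers (the DeepSeek-V3 convention described in the note, not code from the library):

```python
# Hypothetical values; mirrors the `if factor is None` fallback in _compute_yarn_parameters.
max_position_embeddings = 163840            # post-yarn context length from the config
original_max_position_embeddings = 4096     # pretraining-time context length from rope_parameters
factor = None                                # not set explicitly in rope_parameters

if factor is None:
    factor = max_position_embeddings / original_max_position_embeddings
print(factor)  # 40.0
```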
@@ -439,7 +436,6 @@ def _compute_longrope_parameters(
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
# For backward compatibility standardize the `rope_parameters_dict` if it uses old format
config.standardize_rope_params()
rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -453,14 +449,13 @@
short_factor = rope_parameters_dict["short_factor"]
factor = rope_parameters_dict.get("factor")
attention_factor = rope_parameters_dict.get("attention_factor")
original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]

# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None):
if factor is None:
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings

# Sets the attention factor as suggested in the paper
if attention_factor is None:
@@ -586,7 +581,7 @@ class RopeParameters(TypedDict, total=False):
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
original_max_position_embeddings (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
Comment on lines 583 to 585 (Member Author):

'dynamic' uses `config.max_position_embeddings` and doesn't require us to set an explicit `original_max_position_embeddings` in the rope dict.
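A small numeric sketch of why 'dynamic' gets by with `config.max_position_embeddings` alone (made-up values; the formula mirrors `_compute_dynamic_ntk_parameters` earlier in this diff):

```python
# Hypothetical values; the base rescaling only needs the ratio seq_len / max_position_embeddings,
# so no separate pretraining-length key is required in the rope dict.
base = 10000.0
factor = 2.0
dim = 128
max_position_embeddings = 4096
seq_len = 8192  # longer than the pretraining length, so the base grows

scaled_base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
print(scaled_base > base)  # True
```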

attention_factor (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
@@ -640,6 +635,7 @@ def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: Optional[s

# Standardize and validate the correctness of rotary position embeddings parameters
self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))

if "partial_rotary_factor" in kwargs:
self.rope_parameters.setdefault("partial_rotary_factor", kwargs["partial_rotary_factor"])
ignore_keys_at_rope_validation = {"partial_rotary_factor"}
@@ -664,14 +660,30 @@ def standardize_rope_params(self):
rope_parameters.setdefault("rope_theta", rope_theta)
if partial_rotary_factor is not None:
rope_parameters["partial_rotary_factor"] = partial_rotary_factor

# Move pretraining-time maximum length to rope parameter dict for RoPE types with scaling
if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
if hasattr(self, "original_max_position_embeddings"):
# NOTE: Phi3 (and potentially other models) save an `original_max_position_embeddings` field
# containing the pretrained value outside rope parameters. This is an exception case where we
# give priority to `self.original_max_position_embeddings`
self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings
else:
self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings)

# Case 2: different RoPE for each layer -> several params as nested dict
else:
for layer_type in self.layer_types:
for layer_type in set(self.layer_types):
rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default"))
rope_parameters[layer_type].setdefault("rope_theta", rope_theta)
if partial_rotary_factor is not None:
rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor

if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]:
self.rope_parameters[layer_type].setdefault(
"original_max_position_embeddings", self.max_position_embeddings
)

self.rope_parameters = rope_parameters

def validate_rope(self: "PreTrainedConfig", ignore_keys: Optional[set] = None):
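A minimal sketch of what the per-layer branch above produces (illustrative dicts, not library code); only the scaled RoPE types pick up the pretraining length:

```python
# Hypothetical per-layer RoPE dicts before standardization.
max_position_embeddings = 32768
rope_parameters = {
    "full_attention": {"rope_type": "yarn", "factor": 8.0, "rope_theta": 10000.0},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
}

for layer_type, params in rope_parameters.items():
    # Mirrors the setdefault in `standardize_rope_params` for llama3 / yarn / longrope.
    if params["rope_type"] in ["llama3", "yarn", "longrope"]:
        params.setdefault("original_max_position_embeddings", max_position_embeddings)

print("original_max_position_embeddings" in rope_parameters["full_attention"])     # True
print("original_max_position_embeddings" in rope_parameters["sliding_attention"])  # False
```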
@@ -718,26 +730,24 @@ def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: O
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"}
required_keys = {"rope_type", "factor"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

factor = rope_parameters["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
required_keys = {"rope_type", "factor", "rope_theta"}
required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"}
optional_keys = {
"attention_factor",
"beta_fast",
"beta_slow",
"original_max_position_embeddings",
"mscale",
"mscale_all_dim",
"truncate",
}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
@@ -765,37 +775,24 @@ def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Opt
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)

# Models should set `config.rope_parameters["original_max_position_embeddings"]` to their original (pre-yarn) context
# length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
# However, for BC purposes, we allow the former to be unset.
original_max_position_embeddings = self.rope_parameters.get("original_max_position_embeddings")
if original_max_position_embeddings is not None:
# Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
implicit_factor = self.max_position_embeddings / original_max_position_embeddings
if implicit_factor != factor:
logger.warning_once(
f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
"the ratio implicitly set by other parameters (implicit factor = "
"post-yarn context length / pre-yarn context length = "
"config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
"behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
)
# No `config.rope_parameters["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
# pre-yarn or the post-yarn context length?
# BC: we assume it is the pre-yarn context length.
else:
# Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
# NOTE: we might get `implicit_factor == 1` if config's `original_max_position_embeddings` was
# inferred from `max_position_embeddings` during standardization
original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"]
implicit_factor = self.max_position_embeddings / original_max_position_embeddings
if implicit_factor != factor and implicit_factor != 1:
logger.warning_once(
"config.rope_parameters['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
"**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
"config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
"factor) -- we recommend updating both fields for optimal downstream model usage."
f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
"the ratio implicitly set by other parameters (implicit factor = "
"post-yarn context length / pre-yarn context length = "
"config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
"behaviour in model usage, please correct the 'original_max_position_embeddings' fields in the model config."
)
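To illustrate the double-check above (made-up numbers, a sketch rather than the validator itself):

```python
# Hypothetical values; mirrors the implicit-factor comparison in _validate_yarn_rope_parameters.
max_position_embeddings = 131072             # post-yarn context length
original_max_position_embeddings = 4096      # pre-yarn context length from rope_parameters
factor = 16.0                                 # explicitly set scaling factor

implicit_factor = max_position_embeddings / original_max_position_embeddings  # 32.0
if implicit_factor != factor and implicit_factor != 1:
    print(f"warning: explicit factor {factor} != implicit factor {implicit_factor}")
```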

def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"}
optional_keys = {"attention_factor", "factor"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
@@ -820,23 +817,23 @@ def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys:
f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
)

# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_parameters` and is
# unique to longrope (= undesirable)
if hasattr(self, "original_max_position_embeddings"):
factor = rope_parameters["factor"]
original_max_position_embeddings = rope_parameters["original_max_position_embeddings"]

# Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter is undesirable
if factor is None and original_max_position_embeddings is not None:
logger.warning_once(
"This model has set a `original_max_position_embeddings` field, to be used together with "
"This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with "
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`"
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
"as it is compatible with most model architectures."
)
else:
factor = rope_parameters.get("factor")
elif original_max_position_embeddings is None:
if factor is None:
logger.warning("Missing required keys in `rope_parameters`: 'factor'")
elif not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

attention_factor = rope_parameters.get("attention_factor")
if attention_factor is not None:
if not isinstance(attention_factor, float) or attention_factor < 0.0:
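Putting the longrope requirements above together, here is a sketch of a parameter dict that would satisfy this validation (hypothetical values; `hidden_size` and `num_attention_heads` are illustrative):

```python
# Hypothetical model dimensions; dim // 2 entries are required for each factor list.
hidden_size, num_attention_heads, partial_rotary_factor = 3072, 24, 1.0
dim = int((hidden_size // num_attention_heads) * partial_rotary_factor)  # 128

rope_parameters = {
    "rope_type": "longrope",
    "rope_theta": 10000.0,
    "short_factor": [1.0] * (dim // 2),
    "long_factor": [2.0] * (dim // 2),
    "original_max_position_embeddings": 4096,   # now a required key
    "factor": 32.0,                              # preferred over the Phi3-style ratio
}
assert len(rope_parameters["short_factor"]) == dim // 2
assert len(rope_parameters["long_factor"]) == dim // 2
```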
1 change: 1 addition & 0 deletions tests/models/gemma3/test_modeling_gemma3.py
@@ -158,6 +158,7 @@ def test_model_rope_scaling_from_config(self):
def test_model_rope_scaling_frequencies(self):
"""Tests the frequency properties of the different RoPE scaling types on the model RoPE layer."""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.layer_types = ["full_attention", "sliding_attention"]
Comment (Member Author):

The code below sets RoPE params for two layer types, but the dummy config isn't always initialized with both. This line makes sure the layer types are in line with the RoPE params.
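A rough sketch of the alignment this enforces (illustrative values, not the test's actual config):

```python
# Hypothetical config fields; every layer type listed in `layer_types` needs a matching
# entry in the per-layer `rope_parameters` dict for the RoPE scaling test to be meaningful.
layer_types = ["full_attention", "sliding_attention"]
rope_parameters = {
    "full_attention": {"rope_type": "default", "rope_theta": 1000000.0},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
}
assert set(layer_types) == set(rope_parameters.keys())
```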


# Retrieves the RoPE layer class from the base model class. Uses `.named_modules()` to avoid hardcoding the
# named location of the RoPE layer class.