diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index de27e5f8bd20..759aa5db6b37 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -45,17 +45,19 @@ def dynamic_rope_update(rope_forward):
     def longrope_frequency_update(self, position_ids, device, layer_type=None):
         """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
         seq_len = torch.max(position_ids) + 1
-        original_max_position_embeddings = getattr(
-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
-        )
+
         if layer_type is None:
             rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
+            original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"]
         else:
             rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"
+            original_max_position_embeddings = self.config.rope_parameters[layer_type][
+                "original_max_position_embeddings"
+            ]
 
         if seq_len > original_max_position_embeddings:
             if not hasattr(self, f"{layer_type}_long_inv_freq"):
@@ -222,7 +224,6 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-    # TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
     # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
     config.standardize_rope_params()
     rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -231,23 +232,22 @@ def _compute_dynamic_ntk_parameters(
     partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
     dim = int(head_dim * partial_rotary_factor)
-    max_position_embeddings = config.max_position_embeddings
     factor = rope_parameters_dict["factor"]
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # seq_len: default to max_position_embeddings, e.g. at init time
     if seq_len is None:
-        seq_len = max_position_embeddings
+        seq_len = config.max_position_embeddings
     elif isinstance(seq_len, torch.Tensor):
         seq_len = torch.maximum(
             seq_len,
-            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+            torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
         )
     else:
-        seq_len = max(seq_len, max_position_embeddings)
+        seq_len = max(seq_len, config.max_position_embeddings)
 
     # Compute the inverse frequencies
-    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+    base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
     return inv_freq, attention_factor
@@ -291,8 +291,7 @@ def _compute_yarn_parameters(
              `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing the
              denominator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be
              calculated based on `factor` only.
-            * `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used
-              during pretraining. If not provided, the function falls back to `max_position_embeddings`.
+            * `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining.
             * `truncate` (`bool`, *optional*): Whether to truncate the correction range.
 
     Additionally, this function will make use of the following properties if they are found in the config:
@@ -323,15 +322,13 @@ def _compute_yarn_parameters(
     attention_factor = rope_parameters_dict.get("attention_factor")
     mscale = rope_parameters_dict.get("mscale")
     mscale_all_dim = rope_parameters_dict.get("mscale_all_dim")
+    original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
 
-    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in rope_parameters_dict:
-        original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
+    # NOTE: DeepSeek-V3 (and potentially other models) have an `original_max_position_embeddings` field
+    # containing the pretrained value. They use the ratio between `max_position_embeddings` and this value
+    # to compute the default attention scaling factor, instead of using `factor`.
+    if factor is None:
         factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
 
     def get_mscale(scale, mscale=1):
         if scale <= 1:
@@ -439,7 +436,6 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
-    # TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
     # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
     config.standardize_rope_params()
     rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -453,14 +449,13 @@ def _compute_longrope_parameters(
     short_factor = rope_parameters_dict["short_factor"]
     factor = rope_parameters_dict.get("factor")
     attention_factor = rope_parameters_dict.get("attention_factor")
+    original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
 
     # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
-    if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None):
+    if factor is None:
         factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
@@ -586,7 +581,7 @@ class RopeParameters(TypedDict, total=False):
             most scaling types, a `factor` of x will enable the model to handle sequences of length x * original
             maximum pre-trained length.
         original_max_position_embeddings (`int`, *optional*):
-            Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+            Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during
             pretraining.
         attention_factor (`float`, *optional*):
             Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
@@ -640,6 +635,7 @@ def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: Optional[s
 
         # Standardize and validate the correctness of rotary position embeddings parameters
         self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))
+
         if "partial_rotary_factor" in kwargs:
             self.rope_parameters.setdefault("partial_rotary_factor", kwargs["partial_rotary_factor"])
             ignore_keys_at_rope_validation = {"partial_rotary_factor"}
@@ -664,14 +660,30 @@ def standardize_rope_params(self):
             rope_parameters.setdefault("rope_theta", rope_theta)
             if partial_rotary_factor is not None:
                 rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+
+            # Move the pretraining-time maximum length to the rope parameter dict for RoPE types with scaling
+            if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
+                if hasattr(self, "original_max_position_embeddings"):
+                    # NOTE: Phi3 (and potentially other models) save the `original_max_position_embeddings` field
+                    # containing the pretrained value outside the rope parameters. This is an exception case where we
+                    # give priority to `self.original_max_position_embeddings`
+                    self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings
+                else:
+                    self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings)
+
         # Case 2: different RoPE for each layer -> several params as nested dict
         else:
-            for layer_type in self.layer_types:
+            for layer_type in set(self.layer_types):
                 rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default"))
                 rope_parameters[layer_type].setdefault("rope_theta", rope_theta)
                 if partial_rotary_factor is not None:
                     rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor
+                if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]:
+                    self.rope_parameters[layer_type].setdefault(
+                        "original_max_position_embeddings", self.max_position_embeddings
+                    )
+
         self.rope_parameters = rope_parameters
 
     def validate_rope(self: "PreTrainedConfig", ignore_keys: Optional[set] = None):
@@ -718,26 +730,24 @@ def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: O
             logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
     def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-        # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
-        optional_keys = {"original_max_position_embeddings"}
         required_keys = {"rope_type", "factor"}
         received_keys = set(rope_parameters.keys())
         rope_type = rope_parameters["rope_type"]
-        self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+        self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
         factor = rope_parameters["factor"]
         if factor is None or not isinstance(factor, float) or factor < 1.0:
             logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
     def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-        required_keys = {"rope_type", "factor", "rope_theta"}
+        required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"}
         optional_keys = {
             "attention_factor",
             "beta_fast",
             "beta_slow",
-            "original_max_position_embeddings",
             "mscale",
             "mscale_all_dim",
+            "truncate",
         }
         received_keys = set(rope_parameters.keys())
         rope_type = rope_parameters["rope_type"]
@@ -765,37 +775,24 @@ def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Opt
                 f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
             )
 
-        # Models should set `config.rope_parameters["original_max_position_embeddings"]` to their original (pre-yarn) context
-        # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
-        # However, for BC purposes, we allow the former to be unset.
-        original_max_position_embeddings = self.rope_parameters.get("original_max_position_embeddings")
-        if original_max_position_embeddings is not None:
-            # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
-            implicit_factor = self.max_position_embeddings / original_max_position_embeddings
-            if implicit_factor != factor:
-                logger.warning_once(
-                    f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
-                    "the ratio implicitly set by other parameters (implicit factor = "
-                    "post-yarn context length / pre-yarn context length = "
-                    "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
-                    f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
-                    "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
-                )
-        # No `config.rope_parameters["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
-        # pre-yarn or the post-yarn context length?
-        # BC: we assume it is the pre-yarn context length.
-        else:
+        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+        # NOTE: we might get `implicit_factor == 1` if the config's `original_max_position_embeddings` was
+        # inferred from `max_position_embeddings` during standardization.
+        original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"]
+        implicit_factor = self.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor and implicit_factor != 1:
             logger.warning_once(
-                "config.rope_parameters['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
-                "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
-                "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
-                "factor) -- we recommend updating both fields for optimal downstream model usage."
+                f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage, please correct the 'original_max_position_embeddings' field in the model config."
             )
 
     def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-        required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta"}
-        # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
-        optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
+        required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"}
+        optional_keys = {"attention_factor", "factor"}
         received_keys = set(rope_parameters.keys())
         rope_type = rope_parameters["rope_type"]
         self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
@@ -820,23 +817,23 @@ def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys:
                 f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
             )
 
-        # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
-        # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_parameters` and is
-        # unique to longrope (= undesirable)
-        if hasattr(self, "original_max_position_embeddings"):
+        factor = rope_parameters["factor"]
+        original_max_position_embeddings = rope_parameters["original_max_position_embeddings"]
+
+        # Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over
+        # `original_max_position_embeddings` to compute internal variables. The latter is undesirable
+        if factor is None and original_max_position_embeddings is not None:
             logger.warning_once(
-                "This model has set a `original_max_position_embeddings` field, to be used together with "
+                "This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with "
                 "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`"
                 "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
                 "as it is compatible with most model architectures."
             )
-        else:
-            factor = rope_parameters.get("factor")
+        elif original_max_position_embeddings is None:
             if factor is None:
                 logger.warning("Missing required keys in `rope_parameters`: 'factor'")
             elif not isinstance(factor, float) or factor < 1.0:
                 logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
-        attention_factor = rope_parameters.get("attention_factor")
         if attention_factor is not None:
             if not isinstance(attention_factor, float) or attention_factor < 0.0:
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index 17e9d9991bd4..ea8f40480765 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -158,6 +158,7 @@ def test_model_rope_scaling_from_config(self):
     def test_model_rope_scaling_frequencies(self):
         """Tests the frequency properties of the different RoPE scaling types on the model RoPE layer."""
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config.layer_types = ["full_attention", "sliding_attention"]
 
         # Retrieves the RoPE layer class from the base model class. Uses `.named_modules()` to avoid hardcoding the
         # named location of the RoPE layer class.
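A minimal standalone sketch (not part of the patch) of the configuration-side behaviour introduced above: for the scaled RoPE types ('llama3', 'yarn', 'longrope'), `original_max_position_embeddings` now lives inside `rope_parameters`, a Phi3-style top-level `original_max_position_embeddings` attribute takes priority, otherwise it defaults to `max_position_embeddings`, and a missing `factor` is inferred from the ratio of the two lengths. The `SimpleNamespace` config and the `standardize` helper below are hypothetical stand-ins, not the library API.

from types import SimpleNamespace

def standardize(config):
    params = config.rope_parameters
    if params.get("rope_type", "default") in ["llama3", "yarn", "longrope"]:
        if hasattr(config, "original_max_position_embeddings"):
            # Phi3-style configs keep the pretrained length outside the rope dict; it takes priority.
            params["original_max_position_embeddings"] = config.original_max_position_embeddings
        else:
            params.setdefault("original_max_position_embeddings", config.max_position_embeddings)
    if params.get("factor") is None:
        # Same fallback ratio the patched yarn/longrope helpers use when `factor` is absent.
        params["factor"] = config.max_position_embeddings / params["original_max_position_embeddings"]
    return params

cfg = SimpleNamespace(
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,  # Phi3-style field outside rope_parameters
    rope_parameters={"rope_type": "longrope", "rope_theta": 10000.0},
)
print(standardize(cfg)["factor"])  # 131072 / 4096 = 32.0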
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index 730c4967368e..8bc9458b4ddc 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -101,17 +101,15 @@ def test_yarn_original_original_max_position_embeddings_validation(self):
         with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
             config.validate_rope()
 
-        # bad rope config, no `original_max_position_embeddings` -> warning
+        # bad rope config, no `original_max_position_embeddings` -> raise error
         rope_config = {
            "rope_type": "yarn",
            "rope_theta": 10000.0,
            "factor": 2.0,
         }
         config.rope_parameters = rope_config
-        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+        with self.assertRaises(KeyError):
             config.validate_rope()
-        self.assertEqual(len(logs.output), 1)
-        self.assertIn("is unset", logs.output[0])
 
         # bad rope config, bad implicit fator -> warning
         rope_config = {
@@ -338,9 +336,8 @@ def test_longrope_rope_numerically(self):
         default_inv_freq, _ = rope_fn(config=config, device=torch_device)
 
         # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
-        # `math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))`
+        # `math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))`
         rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
-        max_position_embeddings = config.max_position_embeddings
         for factor in (2.0, 10.0, 20.0):
             config.rope_parameters = {
                 "rope_type": "longrope",
@@ -350,7 +347,9 @@ def test_longrope_rope_numerically(self):
                 "long_factor": long_factor,
             }
             _, attention_scale = rope_fn(config=config, device=torch_device)
-            self.assertEqual(attention_scale, math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)))
+            self.assertEqual(
+                attention_scale, math.sqrt(1 + math.log(factor) / math.log(config.max_position_embeddings))
+            )
 
         config.rope_parameters = {
             "rope_type": "longrope",
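For context on the updated longrope test above, here is a small sketch of the default attention scaling it asserts: when `attention_factor` is unset, longrope falls back to `sqrt(1 + log(factor) / log(original_max_position_embeddings))`, and because standardization defaults `original_max_position_embeddings` to `config.max_position_embeddings` in that test, the assertion can use either name. The helper and the example pretraining length below are illustrative, not the library API.

import math

def default_longrope_attention_factor(factor: float, original_max_position_embeddings: int) -> float:
    # Formula quoted in the test comment: sqrt(1 + log(factor) / log(original_max_position_embeddings))
    return math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))

for factor in (2.0, 10.0, 20.0):
    # 2048 is just an example pretraining length, not the value used by the test config.
    print(factor, default_longrope_attention_factor(factor, original_max_position_embeddings=2048))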