Move original_max_position_embeddings to rope params
#42513
base: main
Changes from all commits: f2d6567, ac2ca7a, 58b4b96, 84a190e, d6dd1a2, 4bba0e7, f8e2a50
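In short, the PR stops reading the pretraining context length from a top-level config attribute and instead carries it inside `rope_parameters`, next to the other RoPE scaling fields. A hedged sketch of the resulting config shape (field values are made up, not taken from a real checkpoint):

```python
# Before (roughly): the pretrained length lived next to `max_position_embeddings` at the top level.
config_before = {
    "max_position_embeddings": 131072,
    "original_max_position_embeddings": 4096,
    "rope_parameters": {"rope_type": "longrope", "short_factor": [1.0], "long_factor": [2.0]},
}

# After: it is stored inside `rope_parameters`, where the scaling code reads it from.
config_after = {
    "max_position_embeddings": 131072,
    "rope_parameters": {
        "rope_type": "longrope",
        "short_factor": [1.0],
        "long_factor": [2.0],
        "original_max_position_embeddings": 4096,
    },
}
```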
@@ -45,17 +45,19 @@ def dynamic_rope_update(rope_forward):
     def longrope_frequency_update(self, position_ids, device, layer_type=None):
         """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
         seq_len = torch.max(position_ids) + 1
-        original_max_position_embeddings = getattr(
-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
-        )
 
         if layer_type is None:
             rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
+            original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"]
         else:
             rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"
+            original_max_position_embeddings = self.config.rope_parameters[layer_type][
+                "original_max_position_embeddings"
+            ]
 
         if seq_len > original_max_position_embeddings:
             if not hasattr(self, f"{layer_type}_long_inv_freq"):
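A standalone sketch (hypothetical helper, not the transformers implementation) of the new lookup: the pretraining length now always comes from `rope_parameters`, either flat or keyed by layer type.

```python
rope_parameters_single = {"rope_type": "longrope", "original_max_position_embeddings": 4096}
rope_parameters_per_layer = {
    "full_attention": {"rope_type": "longrope", "original_max_position_embeddings": 4096},
    "sliding_attention": {"rope_type": "longrope", "original_max_position_embeddings": 4096},
}

def pretraining_length(rope_parameters: dict, layer_type: str | None = None) -> int:
    # Pick the flat dict or the per-layer sub-dict, then read the pretrained length from it.
    params = rope_parameters if layer_type is None else rope_parameters[layer_type]
    return params["original_max_position_embeddings"]

print(pretraining_length(rope_parameters_single))                       # 4096
print(pretraining_length(rope_parameters_per_layer, "full_attention"))  # 4096
```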
@@ -222,7 +224,6 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-    # TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
     # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
     config.standardize_rope_params()
     rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -231,23 +232,22 @@ def _compute_dynamic_ntk_parameters(
     partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
     dim = int(head_dim * partial_rotary_factor)
-    max_position_embeddings = config.max_position_embeddings
     factor = rope_parameters_dict["factor"]
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # seq_len: default to max_position_embeddings, e.g. at init time
     if seq_len is None:
-        seq_len = max_position_embeddings
+        seq_len = config.max_position_embeddings
     elif isinstance(seq_len, torch.Tensor):
         seq_len = torch.maximum(
             seq_len,
-            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+            torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
         )
     else:
-        seq_len = max(seq_len, max_position_embeddings)
+        seq_len = max(seq_len, config.max_position_embeddings)
 
     # Compute the inverse frequencies
-    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+    base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
     return inv_freq, attention_factor
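The hunk above only inlines `config.max_position_embeddings`; the dynamic-NTK scaling itself is unchanged. A toy sketch of that base adjustment, with made-up values not tied to any checkpoint:

```python
import torch

base, factor, dim = 10000.0, 2.0, 64
max_position_embeddings = 2048
seq_len = 4096  # current sequence length exceeds the pretraining length

# The base grows with seq_len, so the frequencies stretch smoothly beyond the trained context.
scaled_base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

print(scaled_base)     # > 10000.0 once seq_len > max_position_embeddings
print(inv_freq.shape)  # torch.Size([32])
```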
@@ -291,8 +291,7 @@ def _compute_yarn_parameters(
             `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing
             the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor`
             will be calculated based on `factor` only.
-        * `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used
-            during pretraining. If not provided, the function falls back to `max_position_embeddings`.
+        * `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining.
         * `truncate` (`bool`, *optional*): Whether to truncate the correction range.
 
     Additionally, this function will make use of the following properties if they are found in the config:
@@ -323,15 +322,13 @@ def _compute_yarn_parameters(
     attention_factor = rope_parameters_dict.get("attention_factor")
     mscale = rope_parameters_dict.get("mscale")
     mscale_all_dim = rope_parameters_dict.get("mscale_all_dim")
+    original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
 
-    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in rope_parameters_dict:
-        original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
+    # NOTE: DeekSeek-V3 (and potentially other models) have `original_max_position_embeddings` field
+    # containing the pretrained value. They use the ratio between `max_position_embeddings` and this value
+    # to compute the default attention scaling factor, instead of using `factor`.
+    if factor is None:
         factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
 
     def get_mscale(scale, mscale=1):
         if scale <= 1:
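A minimal sketch of the fallback above: when `factor` is absent, it is derived from the ratio of the current context length to the pretrained one. Numbers are illustrative only.

```python
max_position_embeddings = 163840
rope_parameters_dict = {"rope_type": "yarn", "factor": None, "original_max_position_embeddings": 4096}

original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
factor = rope_parameters_dict["factor"]
if factor is None:
    # Fall back to the implicit ratio when no explicit scaling factor is configured.
    factor = max_position_embeddings / original_max_position_embeddings
print(factor)  # 40.0
```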
@@ -439,7 +436,6 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
-    # TODO (joao): use the new `original_max_position_embeddings` from rope_parameters
     # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
     config.standardize_rope_params()
     rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
@@ -453,14 +449,13 @@ def _compute_longrope_parameters(
     short_factor = rope_parameters_dict["short_factor"]
     factor = rope_parameters_dict.get("factor")
     attention_factor = rope_parameters_dict.get("attention_factor")
+    original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
 
     # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
-    if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None):
-        if factor is None:
-            factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
+    if factor is None:
+        factor = config.max_position_embeddings / original_max_position_embeddings
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
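A similar standalone sketch for the longrope path, with made-up Phi3-like numbers. The attention-factor expression at the end is one common variant and an assumption here, not necessarily the exact upstream formula.

```python
import math

max_position_embeddings = 131072
rope_parameters_dict = {
    "rope_type": "longrope",
    "short_factor": [1.0] * 48,
    "long_factor": [2.0] * 48,
    "factor": None,
    "original_max_position_embeddings": 4096,
}

factor = rope_parameters_dict["factor"]
original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
if factor is None:
    # Ratio of extended to pretrained context length.
    factor = max_position_embeddings / original_max_position_embeddings  # 32.0

# One way to derive a default attention scaling factor from that ratio (assumption, see lead-in).
attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
print(factor, round(attention_factor, 3))
```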
@@ -586,7 +581,7 @@ class RopeParameters(TypedDict, total=False):
         most scaling types, a `factor` of x will enable the model to handle sequences of length x *
         original maximum pre-trained length.
     original_max_position_embeddings (`int`, *optional*):
-        Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+        Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during
         pretraining.

Review comment on lines 583 to 585 (Member, Author): dynamic uses

     attention_factor (`float`, *optional*):
         Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
@@ -640,6 +635,7 @@ def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: Optional[s
 
         # Standardize and validate the correctness of rotary position embeddings parameters
         self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))
 
         if "partial_rotary_factor" in kwargs:
             self.rope_parameters.setdefault("partial_rotary_factor", kwargs["partial_rotary_factor"])
             ignore_keys_at_rope_validation = {"partial_rotary_factor"}
@@ -664,14 +660,30 @@ def standardize_rope_params(self):
         rope_parameters.setdefault("rope_theta", rope_theta)
         if partial_rotary_factor is not None:
             rope_parameters["partial_rotary_factor"] = partial_rotary_factor
 
+        # Move pretraining-time maximum length to rope parameter dict for RoPE types with scaling
+        if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
+            if hasattr(self, "original_max_position_embeddings"):
+                # NOTE: Phi3 (and potentially other models) save `original_max_position_embeddings` field
+                # containing the pretrained value outside rope parameters. This is an exception case where we
+                # give priority to `self.original_max_position_embeddings
+                self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings
+            else:
+                self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings)
+
     # Case 2: different RoPE for each layer -> several params as nested dict
     else:
-        for layer_type in self.layer_types:
+        for layer_type in set(self.layer_types):
             rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default"))
             rope_parameters[layer_type].setdefault("rope_theta", rope_theta)
             if partial_rotary_factor is not None:
                 rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor
 
+            if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]:
+                self.rope_parameters[layer_type].setdefault(
+                    "original_max_position_embeddings", self.max_position_embeddings
+                )
+
     self.rope_parameters = rope_parameters
 
 def validate_rope(self: "PreTrainedConfig", ignore_keys: Optional[set] = None):
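A rough standalone sketch of the standardization step above (not the actual `PreTrainedConfig` method), assuming a Phi3-like config that stores the pretrained length as a top-level attribute:

```python
class DummyConfig:
    # Hypothetical config for illustration only.
    def __init__(self):
        self.max_position_embeddings = 131072
        self.original_max_position_embeddings = 4096  # Phi3-style top-level field
        self.rope_parameters = {"rope_type": "longrope", "rope_theta": 10000.0,
                                "short_factor": [1.0], "long_factor": [2.0]}

    def standardize_rope_params(self):
        if self.rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
            if hasattr(self, "original_max_position_embeddings"):
                # The top-level field wins when present (Phi3-style configs).
                self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings
            else:
                # Otherwise fall back to the current maximum length.
                self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings)

config = DummyConfig()
config.standardize_rope_params()
print(config.rope_parameters["original_max_position_embeddings"])  # 4096
```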
@@ -718,26 +730,24 @@ def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: O
         logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
 def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
-    optional_keys = {"original_max_position_embeddings"}
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_parameters.keys())
     rope_type = rope_parameters["rope_type"]
-    self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+    self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
     factor = rope_parameters["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
 def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-    required_keys = {"rope_type", "factor", "rope_theta"}
+    required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"}
     optional_keys = {
         "attention_factor",
         "beta_fast",
         "beta_slow",
-        "original_max_position_embeddings",
         "mscale",
         "mscale_all_dim",
         "truncate",
     }
     received_keys = set(rope_parameters.keys())
     rope_type = rope_parameters["rope_type"]
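A simplified stand-in (not the transformers helper itself) for the kind of key check `_check_received_keys` performs, with `original_max_position_embeddings` now in the required set for yarn:

```python
def check_received_keys(rope_type, received_keys, required_keys, optional_keys=frozenset()):
    # Report keys that are required but missing, and keys that are neither required nor optional.
    missing = required_keys - received_keys
    unexpected = received_keys - required_keys - optional_keys
    if missing:
        print(f"Missing required keys in `rope_parameters` for '{rope_type}': {sorted(missing)}")
    if unexpected:
        print(f"Unrecognized keys in `rope_parameters` for '{rope_type}': {sorted(unexpected)}")

yarn_params = {"rope_type": "yarn", "factor": 4.0, "rope_theta": 10000.0}
check_received_keys(
    "yarn",
    set(yarn_params),
    required_keys={"rope_type", "factor", "rope_theta", "original_max_position_embeddings"},
    optional_keys={"attention_factor", "beta_fast", "beta_slow", "mscale", "mscale_all_dim", "truncate"},
)
# -> Missing required keys in `rope_parameters` for 'yarn': ['original_max_position_embeddings']
```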
@@ -765,37 +775,24 @@ def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: Opt
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
 
-    # Models should set `config.rope_parameters["original_max_position_embeddings"]` to their original (pre-yarn) context
-    # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
-    # However, for BC purposes, we allow the former to be unset.
-    original_max_position_embeddings = self.rope_parameters.get("original_max_position_embeddings")
-    if original_max_position_embeddings is not None:
-        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
-        implicit_factor = self.max_position_embeddings / original_max_position_embeddings
-        if implicit_factor != factor:
-            logger.warning_once(
-                f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
-                "the ratio implicitly set by other parameters (implicit factor = "
-                "post-yarn context length / pre-yarn context length = "
-                "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
-                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
-                "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
-            )
-    # No `config.rope_parameters["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
-    # pre-yarn or the post-yarn context length?
-    # BC: we assume it is the pre-yarn context length.
-    else:
-        logger.warning_once(
-            "config.rope_parameters['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
-            "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
-            "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
-            "factor) -- we recommend updating both fields for optimal downstream model usage."
-        )
+    # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+    # NOTE: we might get `implicit_factor == 1` if config's `original_max_position_embeddings` was
+    # inferred from `max_position_embeddings` during standardization
+    original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"]
+    implicit_factor = self.max_position_embeddings / original_max_position_embeddings
+    if implicit_factor != factor and implicit_factor != 1:
+        logger.warning_once(
+            f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
+            "the ratio implicitly set by other parameters (implicit factor = "
+            "post-yarn context length / pre-yarn context length = "
+            "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
+            f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+            "behaviour in model usage, please correct the 'original_max_position_embeddings' fields in the model config."
+        )
 
 def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: Optional[set] = None):
-    required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta"}
-    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
-    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
+    required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"}
+    optional_keys = {"attention_factor", "factor"}
     received_keys = set(rope_parameters.keys())
     rope_type = rope_parameters["rope_type"]
     self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
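To make the implicit-factor double-check in the yarn validator above concrete, here is a small self-contained sketch with made-up values:

```python
max_position_embeddings = 131072          # post-yarn context length
rope_parameters = {"rope_type": "yarn", "factor": 32.0, "original_max_position_embeddings": 4096}

factor = rope_parameters["factor"]
implicit_factor = max_position_embeddings / rope_parameters["original_max_position_embeddings"]

# Only warn when the two disagree and the pretrained length was not merely inferred from
# `max_position_embeddings` during standardization (in which case implicit_factor == 1).
if implicit_factor != factor and implicit_factor != 1:
    print(f"Warning: explicit factor {factor} != implicit factor {implicit_factor}; the explicit factor wins.")
else:
    print("factor and original_max_position_embeddings are consistent")  # this branch runs here
```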
@@ -820,23 +817,23 @@ def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys:
             logger.warning(
                 f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
             )
 
-    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
-    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_parameters` and is
-    # unique to longrope (= undesirable)
-    if hasattr(self, "original_max_position_embeddings"):
+    factor = rope_parameters["factor"]
+    original_max_position_embeddings = rope_parameters["original_max_position_embeddings"]
+
+    # Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over
+    # `original_max_position_embeddings` to compute internal variables. The latter is undesirable
+    if factor is None and original_max_position_embeddings is not None:
         logger.warning_once(
-            "This model has set a `original_max_position_embeddings` field, to be used together with "
+            "This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with "
             "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`"
             "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
            "as it is compatible with most model architectures."
         )
-    else:
-        factor = rope_parameters.get("factor")
+    elif original_max_position_embeddings is None:
         if factor is None:
             logger.warning("Missing required keys in `rope_parameters`: 'factor'")
         elif not isinstance(factor, float) or factor < 1.0:
             logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
 
     attention_factor = rope_parameters.get("attention_factor")
     if attention_factor is not None:
         if not isinstance(attention_factor, float) or attention_factor < 0.0:
@@ -158,6 +158,7 @@ def test_model_rope_scaling_from_config(self):
     def test_model_rope_scaling_frequencies(self):
         """Tests the frequency properties of the different RoPE scaling types on the model RoPE layer."""
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config.layer_types = ["full_attention", "sliding_attention"]
Review comment (Member, Author): the code below sets rope params for two layer types, but the dummy config doesn't always get initialized with both. This line makes sure that the layer types are in line with the rope params.
 
         # Retrieves the RoPE layer class from the base model class. Uses `.named_modules()` to avoid hardcoding the
         # named location of the RoPE layer class.
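The added line keeps the dummy config's layer types in step with the per-layer RoPE parameters the test sets up. A small sketch of why the alignment matters, using hypothetical dummy values rather than the real test fixture:

```python
layer_types = ["full_attention", "sliding_attention"]
rope_parameters = {
    "full_attention": {"rope_type": "yarn", "rope_theta": 10000.0, "factor": 4.0,
                       "original_max_position_embeddings": 2048},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
}

# Per-layer rope params are keyed by layer type, so the config must declare exactly those types.
assert set(layer_types) == set(rope_parameters), "layer_types must line up with rope_parameters keys"
for layer_type in set(layer_types):  # mirrors the `set(self.layer_types)` de-duplication in standardize_rope_params
    print(layer_type, rope_parameters[layer_type]["rope_type"])
```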
Review comment: safe to assume it already exists. We move `original_max_position_embeddings` to its correct location at config init time.