-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Modify Muon optimizer #21859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Modify Muon optimizer #21859
Changes from 6 commits
441dbba
f8409e7
8ec182a
202923e
6eac82b
12b7db6
8e0c80b
2f23937
e4d7196
69e0f48
39020ea
d38ddca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -111,11 +111,17 @@ def weight_decay_fn(variable): | |||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _backend_update_step(self, grads, trainable_variables, learning_rate): | ||||||||||||||||||||||||||||||||||||||||||
| trainable_variables = [ | ||||||||||||||||||||||||||||||||||||||||||
| v.value if isinstance(v, backend.Variable) else v | ||||||||||||||||||||||||||||||||||||||||||
| for v in trainable_variables | ||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||
| grads_and_vars = list(zip(grads, trainable_variables)) | ||||||||||||||||||||||||||||||||||||||||||
| new_trainable_variables = [] | ||||||||||||||||||||||||||||||||||||||||||
| for v in trainable_variables: | ||||||||||||||||||||||||||||||||||||||||||
| # add variable.path attribute to new variable | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(v, backend.Variable): | ||||||||||||||||||||||||||||||||||||||||||
| new_v = v.value | ||||||||||||||||||||||||||||||||||||||||||
| new_v.path = v.path | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| new_v = v | ||||||||||||||||||||||||||||||||||||||||||
| new_v.path = v.name | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
| new_trainable_variables.append(new_v) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
| new_trainable_variables = [] | |
| for v in trainable_variables: | |
| # add variable.path attribute to new variable | |
| if isinstance(v, backend.Variable): | |
| new_v = v.value | |
| new_v.path = v.path | |
| else: | |
| new_v = v | |
| new_v.path = v.name | |
| new_trainable_variables.append(new_v) | |
| def _prepare_var(v): | |
| if isinstance(v, backend.Variable): | |
| new_v = v.value | |
| new_v.path = v.path | |
| else: | |
| new_v = v | |
| new_v.path = v.name | |
| return new_v | |
| new_trainable_variables = [_prepare_var(v) for v in trainable_variables] |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer): | |||||||||||||
| The Muon optimizer can use both the Muon update step or the | ||||||||||||||
| AdamW update step based on the following: | ||||||||||||||
|
|
||||||||||||||
| - For any variable that isn't 2D, 3D or 4D, the AdamW step | ||||||||||||||
| - For any variable that isn't 2D or 3D , the AdamW step | ||||||||||||||
|
||||||||||||||
| will be used. This is not configurable. | ||||||||||||||
| - If the argument `exclude_embeddings` (defaults to `True`) is set | ||||||||||||||
| to `True`, the AdamW step will be used. | ||||||||||||||
|
|
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer): | |||||||||||||
| that takes no arguments and returns the actual value to use. | ||||||||||||||
| The exponential decay rate for the 1st moment estimates. Defaults to | ||||||||||||||
| `0.9`. | ||||||||||||||
| adam_beta_2: A float value or a constant float tensor, ora callable | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||
| adam_beta_2: A float value or a constant float tensor, or a callable | ||||||||||||||
| that takes no arguments and returns the actual value to use. | ||||||||||||||
| The exponential decay rate for the 2nd moment estimates. Defaults to | ||||||||||||||
| `0.999`. | ||||||||||||||
| adam_weight_decay: Float. If set, weight decay is applied when using | ||||||||||||||
| the Adam optimizer. | ||||||||||||||
| epsilon: A small constant for numerical stability. This is | ||||||||||||||
| "epsilon hat" in the Kingma and Ba paper | ||||||||||||||
| (in the formula just before Section 2.1), | ||||||||||||||
|
|
@@ -72,13 +74,19 @@ class Muon(optimizer.Optimizer): | |||||||||||||
| ns_steps: Integer, number of Newton-Schulz iterations to run. | ||||||||||||||
| nesterov: Boolean, whether to use Nesterov-style momentum | ||||||||||||||
| {{base_optimizer_keyword_args}} | ||||||||||||||
| `rms_rate`: A trick from https://arxiv.org/abs/2502.16982. | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||
| This parameter can enhance the stability of Muon, | ||||||||||||||
| allowing it to use the same learning rate and weight decay as Adam. | ||||||||||||||
| It is disabled by default. | ||||||||||||||
| If you wish to enable it, it is recommended to set it to `0.2`. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| def __init__( | ||||||||||||||
| self, | ||||||||||||||
| learning_rate=0.001, | ||||||||||||||
| adam_beta_1=0.9, | ||||||||||||||
| adam_beta_2=0.999, | ||||||||||||||
| adam_weight_decay=0.004, | ||||||||||||||
| epsilon=1e-7, | ||||||||||||||
| weight_decay=0.1, | ||||||||||||||
| clipnorm=None, | ||||||||||||||
|
|
@@ -99,6 +107,7 @@ def __init__( | |||||||||||||
| momentum=0.95, | ||||||||||||||
| ns_steps=6, | ||||||||||||||
| nesterov=True, | ||||||||||||||
| rms_rate=None, | ||||||||||||||
| **kwargs, | ||||||||||||||
| ): | ||||||||||||||
| super().__init__( | ||||||||||||||
|
|
@@ -127,6 +136,8 @@ def __init__( | |||||||||||||
| self.nesterov = nesterov | ||||||||||||||
| self.exclude_embeddings = exclude_embeddings | ||||||||||||||
| self.exclude_layers = exclude_layers or [] | ||||||||||||||
| self.adam_weight_decay = adam_weight_decay | ||||||||||||||
| self.rms_rate = rms_rate | ||||||||||||||
|
|
||||||||||||||
| def _should_use_adamw(self, variable): | ||||||||||||||
| # To use it with 4D convolutional filters, | ||||||||||||||
|
|
@@ -185,17 +196,15 @@ def update_step(self, gradient, variable, learning_rate): | |||||||||||||
| def _muon_update_step(self, gradient, variable, lr): | ||||||||||||||
| m = self.adam_momentums[variable.path] | ||||||||||||||
| self.assign_add(m, ops.add(gradient, m * (self.momentum - 1))) | ||||||||||||||
| shape = variable.shape | ||||||||||||||
| if self.nesterov: | ||||||||||||||
| g = ops.add(gradient, self.momentum * m) | ||||||||||||||
| else: | ||||||||||||||
| g = m | ||||||||||||||
| update = self.zeropower_via_newtonschulz5(g, self.ns_steps) | ||||||||||||||
|
|
||||||||||||||
| self.assign_sub( | ||||||||||||||
| variable, | ||||||||||||||
| lr | ||||||||||||||
| * self.zeropower_via_newtonschulz5(g, self.ns_steps) | ||||||||||||||
| * max(1, shape[0] / shape[1]) ** 0.5, | ||||||||||||||
| lr * self.rms_macthing(update), | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| def _adamw_update_step(self, gradient, variable, learning_rate): | ||||||||||||||
|
|
@@ -239,6 +248,20 @@ def transpose_last_axis(self, X): | |||||||||||||
| X = ops.transpose(X, temp_order) | ||||||||||||||
| return X | ||||||||||||||
|
|
||||||||||||||
| def rms_macthing(self, x): | ||||||||||||||
|
||||||||||||||
| """ | ||||||||||||||
| You can check the details at https://arxiv.org/pdf/2502.16982. | ||||||||||||||
| For a 2D matrix of size m,the analytical solution provided in the paper | ||||||||||||||
| rate * x * sqrt(max(n,m)) | ||||||||||||||
|
Comment on lines
+255
to
+257
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring has a couple of issues:
Suggested change
|
||||||||||||||
| """ | ||||||||||||||
| if self.rms_rate is None or len(x.shape) != 2: | ||||||||||||||
| # KellerJordan version in muon github | ||||||||||||||
| # https://github.com/KellerJordan/Muon | ||||||||||||||
| return x * max(1, x.shape[-2] / x.shape[-1]) ** 0.5 | ||||||||||||||
| # moonlight version | ||||||||||||||
| # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py | ||||||||||||||
| return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate | ||||||||||||||
|
|
||||||||||||||
| def zeropower_via_newtonschulz5(self, x, steps: int): | ||||||||||||||
| """We apply the Newton-Schulz iteration to compute matrix G. | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -268,6 +291,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int): | |||||||||||||
| x = self.transpose_last_axis(x) | ||||||||||||||
| return x | ||||||||||||||
|
|
||||||||||||||
| def _apply_weight_decay(self, variables): | ||||||||||||||
| for variable in variables: | ||||||||||||||
| if self._use_weight_decay(variable): | ||||||||||||||
| if self._should_use_adamw(variable): | ||||||||||||||
| if self.adam_weight_decay is None: | ||||||||||||||
| continue | ||||||||||||||
| wd = ops.cast(self.adam_weight_decay, variable.dtype) | ||||||||||||||
| else: | ||||||||||||||
| if self.weight_decay is None: | ||||||||||||||
| continue | ||||||||||||||
| wd = ops.cast(self.weight_decay, variable.dtype) | ||||||||||||||
| lr = ops.cast(self.learning_rate, variable.dtype) | ||||||||||||||
| variable.assign(variable - variable * wd * lr) | ||||||||||||||
|
|
||||||||||||||
| def get_config(self): | ||||||||||||||
| config = super().get_config() | ||||||||||||||
| config.update( | ||||||||||||||
|
|
@@ -284,6 +321,8 @@ def get_config(self): | |||||||||||||
| "ns_steps": self.ns_steps, | ||||||||||||||
| "nesterov": self.nesterov, | ||||||||||||||
| "exclude_embeddings": self.exclude_embeddings, | ||||||||||||||
| "adam_weight_decay": self.adam_weight_decay, | ||||||||||||||
| "rms_rate": self.rms_rate, | ||||||||||||||
| } | ||||||||||||||
| ) | ||||||||||||||
| return config | ||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,7 @@ | ||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import keras | ||||||||||||||||||||||
| from keras.src import backend | ||||||||||||||||||||||
| from keras.src import ops | ||||||||||||||||||||||
| from keras.src import testing | ||||||||||||||||||||||
|
|
@@ -81,3 +83,57 @@ def test_clip_value(self): | |||||||||||||||||||||
| grad = [np.array([100.0, 100.0])] | ||||||||||||||||||||||
| clipped_grad = optimizer._clip_gradients(grad) | ||||||||||||||||||||||
| self.assertAllClose(clipped_grad[0], [1.0, 1.0]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_muon_weight_decay(self): | ||||||||||||||||||||||
| variable = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) | ||||||||||||||||||||||
| weight_decay = 0.01 | ||||||||||||||||||||||
| except_varable = variable - variable * weight_decay | ||||||||||||||||||||||
| optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay) | ||||||||||||||||||||||
| optimizer._apply_weight_decay([variable]) | ||||||||||||||||||||||
| self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4) | ||||||||||||||||||||||
|
||||||||||||||||||||||
| except_varable = variable - variable * weight_decay | |
| optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay) | |
| optimizer._apply_weight_decay([variable]) | |
| self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4) | |
| expected_variable = variable - variable * weight_decay | |
| optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay) | |
| optimizer._apply_weight_decay([variable]) | |
| self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a typo in the variable name except_varable. It should be expected_variable for clarity.
| except_varable = variable - variable * weight_decay | |
| optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay) | |
| optimizer._apply_weight_decay([variable]) | |
| self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4) | |
| expected_variable = variable - variable * weight_decay | |
| optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay) | |
| optimizer._apply_weight_decay([variable]) | |
| self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -113,8 +113,9 @@ def error_handler(*args, **kwargs): | |||||||||||||||
| return fn(*args, **kwargs) | ||||||||||||||||
|
|
||||||||||||||||
| filtered_tb = None | ||||||||||||||||
| return fn(*args, **kwargs) | ||||||||||||||||
| try: | ||||||||||||||||
| return fn(*args, **kwargs) | ||||||||||||||||
| pass | ||||||||||||||||
|
||||||||||||||||
| return fn(*args, **kwargs) | |
| try: | |
| return fn(*args, **kwargs) | |
| pass | |
| try: | |
| return fn(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This statement does nothing -- it doesn't create a copy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just my personal coding habit to keep things consistent, even though it doesn’t actually serve a practical purpose. It’s just a small part of my style. Do you think it needs to be changed?