Changes from 6 commits
16 changes: 11 additions & 5 deletions keras/src/backend/tensorflow/optimizer.py
@@ -111,11 +111,17 @@ def weight_decay_fn(variable):
)

def _backend_update_step(self, grads, trainable_variables, learning_rate):
trainable_variables = [
v.value if isinstance(v, backend.Variable) else v
for v in trainable_variables
]
grads_and_vars = list(zip(grads, trainable_variables))
new_trainable_variables = []
for v in trainable_variables:
# add variable.path attribute to new variable
if isinstance(v, backend.Variable):
new_v = v.value
new_v.path = v.path
else:
new_v = v
Collaborator

This statement does nothing -- it doesn't create a copy

Contributor Author

This is just my personal coding habit to keep things consistent, even though it doesn’t actually serve a practical purpose. It’s just a small part of my style. Do you think it needs to be changed?

new_v.path = v.name
Collaborator

This seems dangerous, path and name are different concepts and you can't just overwrite the path attribute in this way

Contributor Author

If we don't do this, TensorFlow still reports an error for tf.Variable. On the TensorFlow backend, it seems reasonable to assume that all inputs are Keras variables (keras.Variable). With the current handling, a tf.Variable no longer uses a path attribute; its name attribute serves as the path instead. In fact, this approach is already widely used in the current version of TensorFlow.

For example, running the following code:

import keras
model = keras.Sequential([
    keras.layers.Input(shape=(10,)),
    keras.layers.Dense(5),
    keras.layers.Dense(10),
    keras.layers.Dense(1, name="last")
])
for w in model.weights:
    print(w.path, w.value.name)

The output is as follows:

sequential_1/dense_2/kernel sequential_1/dense_2/kernel:0
sequential_1/dense_2/bias sequential_1/dense_2/bias:0
sequential_1/dense_3/kernel sequential_1/dense_3/kernel:0
sequential_1/dense_3/bias sequential_1/dense_3/bias:0
sequential_1/last/kernel sequential_1/last/kernel:0
sequential_1/last/bias sequential_1/last/bias:0

The output shows that the name attribute of a TensorFlow variable already acts as a path identifier. Assuming that all inputs are Keras variables is therefore reasonable in the current implementation and consistent with TensorFlow's existing behavior.

Collaborator

Please see the detailed comment here: #21797 (review)

That comment describes how to implement this reliably by inspecting path within build. In update_step, use self._get_variable_index(var) instead. This utility was created specifically to solve this problem.

You should not overwrite name because some users have workflows that depend on it.
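
A minimal sketch of the index-based approach described above, in plain Python. `ToyVariable`, `ToyOptimizer` and the identity-keyed dictionary are hypothetical stand-ins, not the Keras implementation. Only the idea of inspecting `path` once in `build` and resolving per-variable state through `_get_variable_index` in `update_step` comes from the comment.

```python
class ToyVariable:
    """Hypothetical stand-in for a backend variable that carries a path."""

    def __init__(self, path, value):
        self.path = path
        self.value = value


class ToyOptimizer:
    """Illustrative only: decides per-variable behaviour once, in build()."""

    def __init__(self, exclude_layers=()):
        self.exclude_layers = list(exclude_layers)
        self._index_by_id = {}
        self._use_adamw = []

    def build(self, variables):
        for index, var in enumerate(variables):
            # path is inspected only here, while the original variables
            # (with their path attribute) are still available.
            self._index_by_id[id(var)] = index
            excluded = any(key in var.path for key in self.exclude_layers)
            self._use_adamw.append(excluded)

    def _get_variable_index(self, var):
        # Assumption: lookup by object identity, never by path or name.
        return self._index_by_id[id(var)]

    def update_step(self, var):
        # No attribute is read from or written to the variable here.
        index = self._get_variable_index(var)
        return "adamw" if self._use_adamw[index] else "muon"


variables = [
    ToyVariable("sequential/dense/kernel", 1.0),
    ToyVariable("sequential/last/kernel", 2.0),
]
optimizer = ToyOptimizer(exclude_layers=["last"])
optimizer.build(variables)
steps = [optimizer.update_step(v) for v in variables]
```

With this layout nothing is written onto the variable in `update_step`, so neither `path` nor `name` needs to be overwritten.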

new_trainable_variables.append(new_v)
Contributor

medium

This implementation is much clearer and more robust for ensuring the path attribute is correctly propagated to the TensorFlow variables. However, it can be slightly simplified by using a helper function within a list comprehension, which could make the intent even more direct. This is a minor suggestion for code style and readability.

Suggested change
-        new_trainable_variables = []
-        for v in trainable_variables:
-            # add variable.path attribute to new variable
-            if isinstance(v, backend.Variable):
-                new_v = v.value
-                new_v.path = v.path
-            else:
-                new_v = v
-                new_v.path = v.name
-            new_trainable_variables.append(new_v)
+        def _prepare_var(v):
+            if isinstance(v, backend.Variable):
+                new_v = v.value
+                new_v.path = v.path
+            else:
+                new_v = v
+                new_v.path = v.name
+            return new_v
+        new_trainable_variables = [_prepare_var(v) for v in trainable_variables]

grads_and_vars = list(zip(grads, new_trainable_variables))
grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)
tf.__internal__.distribute.interim.maybe_merge_call(
self._distributed_tf_update_step,
51 changes: 45 additions & 6 deletions keras/src/optimizers/muon.py
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
The Muon optimizer can use both the Muon update step or the
AdamW update step based on the following:

- For any variable that isn't 2D, 3D or 4D, the AdamW step
- For any variable that isn't 2D or 3D , the AdamW step
Contributor

medium

There's a minor formatting issue in the docstring. The space before the comma should be removed to improve readability.

Suggested change
-    - For any variable that isn't 2D or 3D , the AdamW step
+    - For any variable that isn't 2D or 3D, the AdamW step

will be used. This is not configurable.
- If the argument `exclude_embeddings` (defaults to `True`) is set
to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
that takes no arguments and returns the actual value to use.
The exponential decay rate for the 1st moment estimates. Defaults to
`0.9`.
adam_beta_2: A float value or a constant float tensor, ora callable
Contributor

medium

There's a typo in the docstring. ora should be or a.

Suggested change
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable

adam_beta_2: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use.
The exponential decay rate for the 2nd moment estimates. Defaults to
`0.999`.
adam_weight_decay: Float. If set, weight decay is applied when using
the Adam optimizer.
epsilon: A small constant for numerical stability. This is
"epsilon hat" in the Kingma and Ba paper
(in the formula just before Section 2.1),
@@ -72,13 +74,19 @@ class Muon(optimizer.Optimizer):
ns_steps: Integer, number of Newton-Schulz iterations to run.
nesterov: Boolean, whether to use Nesterov-style momentum
{{base_optimizer_keyword_args}}
`rms_rate`: A trick from https://arxiv.org/abs/2502.16982.
Contributor

medium

The arXiv link appears to have a typo in the year. It points to 2502, but it should likely be 2024. Please verify and correct the link.

Suggested change
-        `rms_rate`: A trick from https://arxiv.org/abs/2502.16982.
+        `rms_rate`: A trick from https://arxiv.org/abs/2402.16982.

This parameter can enhance the stability of Muon,
allowing it to use the same learning rate and weight decay as Adam.
It is disabled by default.
If you wish to enable it, it is recommended to set it to `0.2`.
"""

def __init__(
self,
learning_rate=0.001,
adam_beta_1=0.9,
adam_beta_2=0.999,
adam_weight_decay=0.004,
epsilon=1e-7,
weight_decay=0.1,
clipnorm=None,
@@ -99,6 +107,7 @@ def __init__(
momentum=0.95,
ns_steps=6,
nesterov=True,
rms_rate=None,
**kwargs,
):
super().__init__(
@@ -127,6 +136,8 @@ def __init__(
self.nesterov = nesterov
self.exclude_embeddings = exclude_embeddings
self.exclude_layers = exclude_layers or []
self.adam_weight_decay = adam_weight_decay
self.rms_rate = rms_rate

def _should_use_adamw(self, variable):
# To use it with 4D convolutional filters,
@@ -185,17 +196,15 @@ def update_step(self, gradient, variable, learning_rate):
def _muon_update_step(self, gradient, variable, lr):
m = self.adam_momentums[variable.path]
self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
shape = variable.shape
if self.nesterov:
g = ops.add(gradient, self.momentum * m)
else:
g = m
update = self.zeropower_via_newtonschulz5(g, self.ns_steps)

self.assign_sub(
variable,
lr
* self.zeropower_via_newtonschulz5(g, self.ns_steps)
* max(1, shape[0] / shape[1]) ** 0.5,
lr * self.rms_macthing(update),
)

def _adamw_update_step(self, gradient, variable, learning_rate):
@@ -239,6 +248,20 @@ def transpose_last_axis(self, X):
X = ops.transpose(X, temp_order)
return X

def rms_macthing(self, x):
Contributor

medium

There's a typo in the method name. It should be rms_matching instead of rms_macthing. Please correct it here and at all call sites (in _muon_update_step and the test files) for consistency and clarity.

Suggested change
-    def rms_macthing(self, x):
+    def rms_matching(self, x):

"""
You can check the details at https://arxiv.org/pdf/2502.16982.
For a 2D matrix of size m,the analytical solution provided in the paper
rate * x * sqrt(max(n,m))
Comment on lines +255 to +257
Contributor

medium

The docstring has a couple of issues:

  1. The arXiv link seems to have a typo in the year. It points to 2502, but it should likely be 2024. Please verify and correct the link.
  2. There's a typo on the next line: m,the should be m, the.
Suggested change
-        You can check the details at https://arxiv.org/pdf/2502.16982.
-        For a 2D matrix of size m,the analytical solution provided in the paper
-        rate * x * sqrt(max(n,m))
+        You can check the details at https://arxiv.org/pdf/2402.16982.
+        For a 2D matrix of size m, the analytical solution provided in the paper
+        rate * x * sqrt(max(n,m))

"""
if self.rms_rate is None or len(x.shape) != 2:
# KellerJordan version in muon github
# https://github.com/KellerJordan/Muon
return x * max(1, x.shape[-2] / x.shape[-1]) ** 0.5
# moonlight version
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
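
The two scaling branches above can be sketched as a scalar factor in plain Python. `rms_scale` is a hypothetical helper, not part of the PR. It takes only a shape and an optional rate, and assumes the update is multiplied elementwise by the returned value.

```python
import math


def rms_scale(shape, rms_rate=None):
    """Return the scalar multiplier applied to the orthogonalized update."""
    if rms_rate is None or len(shape) != 2:
        # Default branch: max(1, rows / cols) ** 0.5 over the last two axes.
        return max(1.0, shape[-2] / shape[-1]) ** 0.5
    # RMS-matching branch: rms_rate * sqrt(max(rows, cols)).
    return rms_rate * math.sqrt(max(shape[0], shape[1]))


default_scale = rms_scale((4, 4))            # rms_rate is None
matched_scale = rms_scale((4, 2), 0.2)       # 0.2 * sqrt(4)
fallback_scale = rms_scale((2, 4, 4), 0.1)   # 3D input ignores rms_rate
```

These three cases mirror the shapes used in the `test_rms_matching_*` tests in this PR: a square matrix without a rate and a 3D input both keep a factor of 1, while a `(4, 2)` matrix with rate `0.2` is scaled by `0.4`.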

def zeropower_via_newtonschulz5(self, x, steps: int):
"""We apply the Newton-Schulz iteration to compute matrix G.

@@ -268,6 +291,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
x = self.transpose_last_axis(x)
return x

def _apply_weight_decay(self, variables):
for variable in variables:
if self._use_weight_decay(variable):
if self._should_use_adamw(variable):
if self.adam_weight_decay is None:
continue
wd = ops.cast(self.adam_weight_decay, variable.dtype)
else:
if self.weight_decay is None:
continue
wd = ops.cast(self.weight_decay, variable.dtype)
lr = ops.cast(self.learning_rate, variable.dtype)
variable.assign(variable - variable * wd * lr)
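
A minimal sketch of the arithmetic in `_apply_weight_decay`, using plain floats. `decayed_value` and its `use_adamw` flag are hypothetical. The flag stands in for `_should_use_adamw(variable)`, and the `_use_weight_decay` exclusion check is left out for brevity.

```python
def decayed_value(value, learning_rate, use_adamw,
                  weight_decay=0.1, adam_weight_decay=0.004):
    """Apply decoupled weight decay with a per-branch coefficient."""
    # Pick the coefficient for the branch this variable is updated with.
    wd = adam_weight_decay if use_adamw else weight_decay
    if wd is None:
        # A missing coefficient means no decay for that branch.
        return value
    return value - value * wd * learning_rate


adamw_result = decayed_value(2.0, 1.0, use_adamw=True, adam_weight_decay=0.01)
muon_result = decayed_value(4.0, 1.0, use_adamw=False, weight_decay=0.01)
skipped = decayed_value(4.0, 1.0, use_adamw=False, weight_decay=None)
```

The default coefficients shown are the ones in the `__init__` signature of this diff (`weight_decay=0.1`, `adam_weight_decay=0.004`). Each branch decays independently, so setting one coefficient to `None` does not disable the other.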

def get_config(self):
config = super().get_config()
config.update(
@@ -284,6 +321,8 @@ def get_config(self):
"ns_steps": self.ns_steps,
"nesterov": self.nesterov,
"exclude_embeddings": self.exclude_embeddings,
"adam_weight_decay": self.adam_weight_decay,
"rms_rate": self.rms_rate,
}
)
return config
56 changes: 56 additions & 0 deletions keras/src/optimizers/muon_test.py
@@ -1,5 +1,7 @@
import numpy as np
import pytest

import keras
from keras.src import backend
from keras.src import ops
from keras.src import testing
@@ -81,3 +83,57 @@ def test_clip_value(self):
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])

def test_muon_weight_decay(self):
variable = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
weight_decay = 0.01
except_varable = variable - variable * weight_decay
optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
optimizer._apply_weight_decay([variable])
self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4)
Contributor

medium

There's a typo in the variable name except_varable. It should be expected_variable for clarity.

Suggested change
-        except_varable = variable - variable * weight_decay
-        optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
-        optimizer._apply_weight_decay([variable])
-        self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4)
+        expected_variable = variable - variable * weight_decay
+        optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
+        optimizer._apply_weight_decay([variable])
+        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)


def test_adamw_weight_decay(self):
variable = backend.Variable(2.0)
weight_decay = 0.01
except_varable = variable - variable * weight_decay
optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
optimizer._apply_weight_decay([variable])

self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4)
Contributor

medium

There's a typo in the variable name except_varable. It should be expected_variable for clarity.

Suggested change
-        except_varable = variable - variable * weight_decay
-        optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
-        optimizer._apply_weight_decay([variable])
-        self.assertAllClose(variable, except_varable, rtol=1e-4, atol=1e-4)
+        expected_variable = variable - variable * weight_decay
+        optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
+        optimizer._apply_weight_decay([variable])
+        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)


def test_rms_matching_none(self):
opt = Muon(rms_rate=None)
x = ops.ones((4, 4))
want = x
self.assertAllClose(opt.rms_macthing(x), want)

def test_rms_matching_2d(self):
opt = Muon(rms_rate=0.2)
x = ops.ones((4, 2))
want = x * 0.2 * 2
self.assertAllClose(opt.rms_macthing(x), want)

def test_rms_matching_3d(self):
opt = Muon(rms_rate=0.1)
x = ops.ones((2, 4, 4))
want = x
self.assertAllClose(opt.rms_macthing(x), want)

@pytest.mark.skipif(
backend.backend() != "tensorflow", reason="Runs only on TF backend"
)
def test_exclude_layers_with_variable_name(self):
optimizer = Muon(learning_rate=0.01, exclude_layers=["last"])

model = keras.Sequential(
[
keras.layers.Dense(5, input_shape=(10,)),
keras.layers.Dense(1, name="last"),
]
)

x_train = np.random.rand(10, 10).astype(np.float32)
y_train = np.random.rand(10, 1).astype(np.float32)

model.compile(optimizer=optimizer, loss="mse")
model.fit(x_train, y_train, epochs=1, batch_size=2, verbose=0)
3 changes: 2 additions & 1 deletion keras/src/utils/traceback_utils.py
@@ -113,8 +113,9 @@ def error_handler(*args, **kwargs):
return fn(*args, **kwargs)

filtered_tb = None
return fn(*args, **kwargs)
try:
return fn(*args, **kwargs)
pass
Contributor

critical

This return statement on line 116 makes the subsequent try...except...finally block unreachable, which effectively disables the traceback filtering feature. This appears to be an unintentional change. The function call should be placed inside the try block to ensure exceptions are caught and their tracebacks can be filtered.

Suggested change
-        return fn(*args, **kwargs)
-        try:
-            return fn(*args, **kwargs)
-            pass
+        try:
+            return fn(*args, **kwargs)
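
A small self-contained sketch of why the early `return` matters. `broken_handler`, `fixed_handler` and the `RuntimeError` wrapper are hypothetical; the wrapper stands in for the traceback filtering, which the real `error_handler` performs differently.

```python
def broken_handler(fn):
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)  # early return: nothing below ever runs
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            raise RuntimeError("filtered") from e

    return wrapper


def fixed_handler(fn):
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Stand-in for the filtering step: the exception is intercepted.
            raise RuntimeError("filtered") from e

    return wrapper


def fail():
    raise ValueError("boom")


def error_type(wrapped):
    """Call wrapped() and report the name of the exception it raises."""
    try:
        wrapped()
    except Exception as e:
        return type(e).__name__
    return None


broken_result = error_type(broken_handler(fail))
fixed_result = error_type(fixed_handler(fail))
```

In the broken variant the original `ValueError` escapes untouched because the `try` block is never entered. In the fixed variant the handler intercepts it.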

except Exception as e:
filtered_tb = _process_traceback_frames(e.__traceback__)
# To get the full stack trace, call: