Description
When the Torch backend is used together with keras.RematScope, any layer whose call method is invoked with keyword arguments loses those arguments and receives the default values instead.
Other backends (JAX, TensorFlow) are not affected.
Reproducible example
import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras
import keras

dim = 4


class RwkvBlock(keras.layers.Layer):
    def call(self, x, v_first=None):
        # Prints True when v_first arrived as None (the default).
        print(f"[{self.name}] v_first={v_first is None}")
        v_first = x
        return x, v_first

    def compute_output_shape(self, **kwargs):
        return [[None, dim], [None, dim]]


def build_model():
    inputs = keras.Input(shape=(dim,))
    x = inputs
    v_first = None
    for i in range(5):
        # v_first is passed as a keyword argument on purpose.
        x, v_first = RwkvBlock(name=f"block{i}")(x, v_first=v_first)
    model = keras.Model(inputs, x)
    model.build((None, dim))
    return model


print("----- normal call -----")
model = build_model()
_ = model(keras.ops.ones((1, dim)))

print("----- inside RematScope -----")
with keras.RematScope(mode="list_of_layers",
                      layer_names=[f"block{i}" for i in range(5)]):
    model = build_model()
    _ = model(keras.ops.ones((1, dim)))
Output
----- normal call -----
[block0] v_first=True
[block1] v_first=False
[block2] v_first=False
[block3] v_first=False
[block4] v_first=False
----- inside RematScope -----
[block0] v_first=True
[block1] v_first=True # <- should be False
[block2] v_first=True
[block3] v_first=True
[block4] v_first=True
Calling the same layer with positional arguments (layer(x, v_first)) works correctly even under remat.
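The difference between keyword and positional calls is consistent with a wrapper that forwards positional arguments but drops keyword arguments. The sketch below is plain Python and uses neither Keras nor Torch. `positional_only_wrapper`, `kwargs_preserving_wrapper` and `call` are hypothetical names chosen for illustration, and whether the Torch remat path actually behaves this way is an assumption, not something confirmed here.

```python
import functools


def positional_only_wrapper(fn):
    """Hypothetical wrapper that forwards positional arguments only."""
    def wrapped(*args, **kwargs):
        # kwargs are accepted but never forwarded, so fn falls back to its defaults.
        return fn(*args)
    return wrapped


def kwargs_preserving_wrapper(fn):
    """Hypothetical fix: bind kwargs first, then reuse the positional-only path."""
    def wrapped(*args, **kwargs):
        bound = functools.partial(fn, **kwargs)
        return positional_only_wrapper(bound)(*args)
    return wrapped


def call(x, v_first=None):
    # Mirrors the check in the reproducer: True means the default was received.
    return v_first is None


lossy = positional_only_wrapper(call)
fixed = kwargs_preserving_wrapper(call)

print(lossy(1, v_first=2))  # keyword argument is dropped
print(lossy(1, 2))          # positional argument survives
print(fixed(1, v_first=2))  # keyword argument survives
```

If the cause is of this kind, binding the keyword arguments before the function reaches the wrapper (as `functools.partial` does here) would make keyword and positional calls behave the same.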
Expected behaviour
The layer should receive the same arguments regardless of whether rematerialisation is enabled or which backend is used.
Environment
- Keras 3.x (current master)
- Torch 2.x
- Python 3.10+