[Torch backend] Layer loses keyword arguments when wrapped by RematScope #21861

@pass-lin

Description

When the Torch backend is used together with keras.RematScope, keyword arguments passed to a layer's call method are dropped, and the layer receives the default values instead.

Other backends (JAX/TF) are not affected.
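
One possible mechanism, sketched here in plain Python, is a remat wrapper that accepts `**kwargs` but forwards only `*args` to the checkpointing call. This is an assumption about the cause, not a confirmed diagnosis of the Keras source; `fake_checkpoint`, `lossy_remat` and `fixed_remat` are hypothetical names used only for illustration.

```python
import functools


def fake_checkpoint(fn, *args):
    """Hypothetical stand-in for a checkpoint API that takes positional args only."""
    return fn(*args)


def lossy_remat(fn):
    """Wrapper that accepts kwargs but never forwards them (suspected failure mode)."""
    def wrapped(*args, **kwargs):
        return fake_checkpoint(fn, *args)  # kwargs are silently dropped here
    return wrapped


def fixed_remat(fn):
    """Wrapper that binds kwargs onto fn before handing it to the checkpoint call."""
    def wrapped(*args, **kwargs):
        return fake_checkpoint(functools.partial(fn, **kwargs), *args)
    return wrapped


def call(x, v_first=None):
    """Mimics RwkvBlock.call: reports whether v_first arrived as None."""
    return v_first is None


print(lossy_remat(call)(1, v_first=1))  # True: the keyword argument was lost
print(lossy_remat(call)(1, 1))          # False: the positional argument survives
print(fixed_remat(call)(1, v_first=1))  # False: the keyword argument survives
```

If the real wrapper behaves like `lossy_remat`, binding the keyword arguments before checkpointing, as `fixed_remat` does, would be one way to make keyword and positional calls behave the same.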


Reproducible example

import os
os.environ["KERAS_BACKEND"] = "torch"

import keras
dim = 4


class RwkvBlock(keras.layers.Layer):
    def call(self, x, v_first=None):
        # Prints True when v_first arrives as None.
        print(f"[{self.name}]  v_first={v_first is None}")
        v_first = x
        return x, v_first

    def compute_output_shape(self, **kwargs):
        return [[None, dim], [None, dim]]

def build_model():
    inputs = keras.Input(shape=(dim,))
    x = inputs
    v_first = None

    for i in range(5):
        x, v_first = RwkvBlock(name=f"block{i}")(x, v_first=v_first)

    model = keras.Model(inputs, x)
    model.build((None, dim))
    return model

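
# Baseline: build and call the model without rematerialisation.
print("----- normal call -----")
model = build_model()
_ = model(keras.ops.ones((1, dim)))

# Same model, built with every block marked for rematerialisation.
print("----- inside RematScope -----")
with keras.RematScope(
    mode="list_of_layers",
    layer_names=[f"block{i}" for i in range(5)],
):
    model = build_model()
_ = model(keras.ops.ones((1, dim)))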

Output

----- normal call -----
[block0]  v_first=True
[block1]  v_first=False
[block2]  v_first=False
[block3]  v_first=False
[block4]  v_first=False
----- inside RematScope -----
[block0]  v_first=True
[block1]  v_first=True          # <- should be False
[block2]  v_first=True
[block3]  v_first=True
[block4]  v_first=True

Calling the same layer with positional arguments (layer(x, v_first)) works correctly even under remat.
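
Since positional calls are unaffected, a possible user-side workaround is to convert keyword arguments to positional ones before calling the layer. The helper below is a minimal sketch using only the standard library; `as_positional` is a hypothetical name, not part of Keras.

```python
import inspect


def as_positional(fn, *args, **kwargs):
    """Return every argument of a call to fn as a positional tuple, in signature order."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    if bound.kwargs:
        # Keyword-only parameters (or extra **kwargs) cannot be passed positionally.
        raise TypeError(f"cannot pass positionally: {sorted(bound.kwargs)}")
    return bound.args


def call(x, v_first=None):
    """Stand-in with the same signature as RwkvBlock.call."""
    return x, v_first


print(as_positional(call, "x", v_first="v"))  # ('x', 'v')
print(as_positional(call, "x"))               # ('x', None)
```

With a layer this would presumably be used as `layer(*as_positional(layer.call, x, v_first=v_first))`, assuming the layer's `call` has no keyword-only parameters.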


Expected behaviour
The layer should receive the same arguments regardless of whether rematerialisation is enabled or which backend is used.


Environment

  • Keras 3.x (current master)
  • Torch 2.x
  • Python 3.10+
