Description
I have the following function that broadcasts two arrays so that the multiplication of all possible combinations can be computed. The function uses both the tile and repeat ops, and I found that tile behaves inconsistently across backends.
- JAX:
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_2>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_3>)
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_4>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_5>)
- TensorFlow:
(<KerasTensor shape=(None, None, None, None), dtype=float32, sparse=False, name=keras_tensor_2>, <KerasTensor shape=(None, None, None, None), dtype=float32, sparse=False, name=keras_tensor_3>)
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_4>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_5>)
It seems that TensorFlow cannot properly infer the shape of the symbolic tensor returned by tile.
Another issue: when the repeats argument of tile is computed from the shape of a symbolic tensor, TensorFlow still works (although all inferred shapes become None), but JAX raises an error: "'str' object has no attribute '_error_repr'". This can be reproduced by replacing repeats with the commented-out line in the code below.
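For reference, the cross-interaction the function is meant to implement can be sketched in plain NumPy, independent of any backend (the shapes below are assumptions mirroring the Input shapes in the reproduction code):

```python
import numpy as np

# x1: (batch, n1, ny, nx), x2: (batch, n2, ny, nx)
x1 = np.arange(12, dtype=np.float32).reshape(1, 3, 2, 2)
x2 = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)

n1, n2 = x1.shape[-3], x2.shape[-3]
# tile x1 along the mode axis: (1, 3, 2, 2) -> (1, 6, 2, 2)
x1b = np.tile(x1, (1, n2, 1, 1))
# repeat x2 along the mode axis: (1, 2, 2, 2) -> (1, 6, 2, 2)
x2b = np.repeat(x2, n1, axis=-3)

print(x1b.shape, x2b.shape)  # (1, 6, 2, 2) (1, 6, 2, 2)
```

With NumPy (and with the JAX backend) both results have the fully known shape (1, 6, 2, 2); only the TensorFlow backend loses the static shape of the tiled tensor.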
Reproduction Code
import os
os.environ["KERAS_BACKEND"] = "jax"
from keras import ops, layers
from keras import Input
# %%
def broadcast(x1, x2):
    """
    Broadcast the shapes of x1 and x2 to allow the computation of cross-interaction.
    - repeating input1: (a, b) -> (a, b, a, b)
    - repeating input2: (c, d) -> (c, c, d, d)
    - result: (a, b) * (c, d) = (a * c, b * c, a * d, b * d)

    Args:
        x1: nD array in shape (..., n1, ny1, nx1) to be broadcast
        x2: nD array in shape (..., n2, ny2, nx2) to be broadcast

    Returns:
        Broadcast nD arrays in shape (..., n1 * n2, ...)

    Examples:
        >>> import numpy as np
        >>> x1 = np.array([[[[0., 1., 2.]],
        ...                 [[3., 4., 5.]]]])
        >>> x2 = np.array([[[[0., 1., 2.]],
        ...                 [[3., 4., 5.]]]])
        >>> x1, x2 = broadcast((x1, np.zeros(np.shape(x1))), (x2, np.zeros(np.shape(x2))))
        >>> np.array(x1[0]) + 1j * np.array(x1[1])
        array([[[[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]],
                [[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]]]], dtype=complex64)
        >>> np.array(x2[0]) + 1j * np.array(x2[1])
        array([[[[0.+0.j, 1.+0.j, 2.+0.j]],
                [[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]]]], dtype=complex64)
    """
    x1real, x1imag = x1
    x2real, x2imag = x2
    x1shape = ops.shape(x1real)[-3]  # spatial mode dimension of x1
    x2shape = ops.shape(x2real)[-3]  # spatial mode dimension of x2
    x1dims = len(ops.shape(x1real))
    # Computing repeats from the symbolic shape raises
    # "'str' object has no attribute '_error_repr'" on the JAX backend:
    # repeats = ops.scatter_update(ops.cast(ops.ones(x1dims), dtype="int32"), [[-3 + x1dims]], [x2shape])
    repeats = [1, 2, 1, 1]
    x1real = ops.tile(x1real, repeats)
    x1imag = ops.tile(x1imag, repeats)
    x2real = ops.repeat(x2real, x1shape, axis=-3)
    x2imag = ops.repeat(x2imag, x1shape, axis=-3)
    return ((x1real, x1imag), (x2real, x2imag))
# %%
class Test(layers.Layer):
    def call(self, inputs1, inputs2):
        return broadcast(inputs1, inputs2)
test = Test()
# %%
x1 = Input(shape=(3, 2, 2))
x2 = Input(shape=(2, 2, 2))
y1, y2 = test((x1, x1), (x2, x2))
print(y1)
print(y2)

Environment
jax 0.5.0
jaxlib 0.5.0
keras 3.8.0
tensorboard 2.18.0
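A possible workaround, sketched here with plain NumPy (keras.ops exposes the same expand_dims / broadcast_to / reshape ops, but I have not verified this against either backend): replace tile with an expand-broadcast-reshape sequence whose target shape is built from Python ints, so no repeats list derived from a symbolic shape is needed and every non-batch dimension stays static.

```python
import numpy as np

x1 = np.arange(12, dtype=np.float32).reshape(1, 3, 2, 2)  # (batch, n1, ny, nx)
n1, n2 = x1.shape[-3], 2  # n2: mode count of the other operand (assumed known)

# (1, 3, 2, 2) -> (1, 1, 3, 2, 2) -> (1, n2, 3, 2, 2) -> (1, n2 * n1, 2, 2)
expanded = np.expand_dims(x1, axis=-4)
target = expanded.shape[:-4] + (n2,) + expanded.shape[-3:]
tiled = np.broadcast_to(expanded, target).reshape(
    x1.shape[:-3] + (n2 * n1,) + x1.shape[-2:]
)

# identical to tiling along the mode axis
assert np.array_equal(tiled, np.tile(x1, (1, n2, 1, 1)))
print(tiled.shape)  # (1, 6, 2, 2)
```

In the Keras version the batch dimension is None, so the final reshape would need -1 in the batch position; whether that preserves static shapes on the TensorFlow backend is untested.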