Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,6 +2742,42 @@ def round(x, decimals=0):

def tile(x, repeats):
x = convert_to_tensor(x)

# Check if repeats contains only concrete integers
# If so, keep it as a Python list/tuple for better shape inference
try:
if isinstance(repeats, (list, tuple)):
# Try to extract concrete integer values
concrete_repeats = []
for r in repeats:
if isinstance(r, int):
concrete_repeats.append(r)
elif hasattr(r, 'numpy') and r.shape == ():
# Scalar tensor with concrete value
concrete_repeats.append(int(r.numpy()))
else:
# Not a concrete value, fall back to tensor path
concrete_repeats = None
break

if concrete_repeats is not None:
# Use concrete repeats directly for better shape inference
repeats = concrete_repeats
# Pad or trim repeats to match x rank
x_rank = x.shape.rank
if x_rank is not None:
if len(repeats) < x_rank:
repeats = [1] * (x_rank - len(repeats)) + repeats
elif len(repeats) > x_rank:
# Need to reshape x to match repeats length
x_shape_list = [1] * (len(repeats) - x_rank) + [d if d is not None else -1 for d in x.shape.as_list()]
x = tf.reshape(x, x_shape_list)
return tf.tile(x, repeats)
except Exception:
# If anything goes wrong, fall back to original implementation
pass

# Original dynamic implementation for non-concrete repeats
repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
repeats_size = tf.size(repeats)
repeats = tf.pad(
Expand Down
16 changes: 15 additions & 1 deletion keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6411,17 +6411,31 @@ def compute_output_spec(self, x):
repeats = self.repeats
if isinstance(repeats, int):
repeats = [repeats]

# Convert repeats to list if it's a tuple or other iterable
# and extract concrete integer values
if not isinstance(repeats, list):
try:
repeats = list(repeats)
except TypeError:
repeats = [repeats]

if len(x_shape) > len(repeats):
repeats = [1] * (len(x_shape) - len(repeats)) + repeats
else:
x_shape = [1] * (len(repeats) - len(x_shape)) + x_shape

output_shape = []
for x_size, repeat in zip(x_shape, repeats):
# Check if repeat is a concrete integer value
# If it's a symbolic tensor or unknown, we can't infer the size
if x_size is None:
output_shape.append(None)
else:
elif isinstance(repeat, int):
output_shape.append(x_size * repeat)
else:
# repeat is symbolic (e.g., KerasTensor, tf.Tensor, etc.)
output_shape.append(None)
return KerasTensor(output_shape, dtype=x.dtype)


Expand Down
24 changes: 24 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,10 @@ def test_tile(self):
self.assertEqual(knp.tile(x, [2]).shape, (None, 6))
self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6))
self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6))

# Test with multi-dimensional input
x = KerasTensor((None, 3, 2, 2))
self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2))

def test_trace(self):
x = KerasTensor((None, 3, None, 5))
Expand Down Expand Up @@ -9507,3 +9511,23 @@ def call(self, x):
model.compile(jit_compile=jit_compile)

model.predict(np.random.randn(1, 8))

def test_tile_shape_inference_in_layer(self):
"""Test that ops.tile properly infers output shape when used in a Layer.

This is a regression test for issue #20914 where TensorFlow backend
would return all-None shapes when tile was called inside a Layer's
call method with concrete integer repeats.
"""
class TileLayer(keras.layers.Layer):
def call(self, x):
# Use concrete integer repeats
repeats = [1, 2, 1, 1]
return knp.tile(x, repeats)

inputs = keras.Input(shape=(3, 2, 2))
output = TileLayer()(inputs)

# With the fix, output shape should be (None, 6, 2, 2)
# Before the fix, it was (None, None, None, None)
self.assertEqual(output.shape, (None, 6, 2, 2))
Loading