Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 346 additions & 3 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import warnings
from collections.abc import Collection, Iterable
from collections.abc import Collection, Iterable, Sequence
from itertools import pairwise
from textwrap import dedent

import numpy as np
from numpy.lib.array_utils import normalize_axis_index

import pytensor
import pytensor.scalar.basic as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import (
DisconnectedType,
_float_zeros_like,
Expand All @@ -25,7 +27,7 @@
from pytensor.scalar import upcast
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import alloc, join, second
from pytensor.tensor.basic import alloc, join, second, split
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
Expand All @@ -43,7 +45,7 @@
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.shape import Shape_i, specify_shape
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
from pytensor.tensor.utils import normalize_reduce_axis
Expand Down Expand Up @@ -2011,6 +2013,345 @@ def concat_with_broadcast(tensor_list, axis=0):
return join(axis, *bcast_tensor_inputs)


class PackHelper:
def __init__(self, axes: int | Sequence[int] | None):
self.axes = tuple(axes) if isinstance(axes, list) else axes
self.op_name = "Pack{axes=" + str(self.axes) + "}"

def _analyze_axes_list(self) -> tuple[int, int, int, int | None]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty gnarly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically you asked for it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How so?

"""
Analyze the provided axes list to determine how many axes are before and after the interval to be raveled, as
well as the minimum and maximum number of axes that the inputs can have.

The rules are:
- Axes must be strictly increasing in both the positive and negative parts of the list.
- Negative axes must come after positive axes.
- There can be at most one "hole" in the axes list, which can be either an implicit hole on an endpoint
(e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [1, -1]).

Returns
-------
n_axes_before: int
The number of axes before the interval to be raveled.
n_axes_after: int
The number of axes after the interval to be raveled.
min_axes: int
The minimum number of axes that the inputs must have.
max_axes: int or None
The maximum number of axes that the inputs can have, or None if there is no strict maximum. A maximum is
only introduced when it would resolve ambiguities in the interpretation of the axes list. For example,
[2, 3] can be either interpreted as having two ravel intervals [:2] and [4:], which is illegal,
unless 3 is interpreted as -1, which is only possible if all inputs have exactly 4 axes. Likewise,
[-3, -1] can be interpreted as having two ravel intervals [:-3], [-3:], unless -3 is interpreted as 0,
which is only possible if all inputs have exactly 3 axes.
"""
axes = self.axes
if axes is None:
return 0, 0, 0, None

if isinstance(axes, int):
axes = [axes]

if len(set(axes)) != len(axes):
raise ValueError("axes must have no duplicates")
if axes is not None and len(axes) == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No Ops should be supported in general. Makes writing code easier because you don't have to think about edge case. Empty axes are supported in most Ops that allow variable number of axes (like sum)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should axes = [] do?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tell me what axes does, and I might know

raise ValueError("axes=[] is ambiguous; use None to ravel all")

first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes))
positive_axes = list(axes[:first_negative_idx])
negative_axes = list(axes[first_negative_idx:])

if not all(a < 0 for a in negative_axes):
raise ValueError("Negative axes must come after positive")

def strictly_increasing(s):
return all(b > a for a, b in pairwise(s))

if (positive_axes and not strictly_increasing(positive_axes)) or (
negative_axes and not strictly_increasing(negative_axes)
):
raise ValueError("Axes must be strictly increasing")

def find_gaps(s):
return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1]

pos_gaps = find_gaps(positive_axes)
neg_gaps = find_gaps(negative_axes)
positive_only = positive_axes and not negative_axes
negative_only = negative_axes and not positive_axes
mixed_case = positive_axes and negative_axes

max_axes: int | None = None

n_explicit_holes = len(pos_gaps) + len(neg_gaps)
if n_explicit_holes > 1:
raise ValueError(
"Too many holes in axes list. There can be at most one hole in the axes list, "
"including implict holes resulting from omitting the 0 or -1 axis."
)

if mixed_case:
if pos_gaps or neg_gaps:
raise ValueError(
"Too many holes in axes list. There can be at most one hole in the axes list, "
"including implict holes resulting from omitting the 0 or -1 axis. Because both "
"positive and negative axes are present, there is always assume to be an explit hole "
"between them."
)
n_before = len(positive_axes)
n_after = len(negative_axes)
min_axes = n_before + n_after

if positive_only:
# There are four cases to consider when all axes are positive:
# 0. There are two implicit gaps (0 is not present) and an explicit gap (e.g. [2, 4])
# This case is always illegal, as there is no interpretation that would result in having
# 1. There is only an implicit right hole (e.g. [0, 1])
# This case is legal, and requires no special interpretation. It corresponds to 'i j *' in einops
# 2. There is an explicit internal hole (e.g. [0, 2])
# This case is legal, but requires interpreting the last axis as -1, which introduces a maximum number
# of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and
# no input to have more than 3 dimensions.
# 2. The axes start at an index greater than 0, but have no internal holes (e.g. [2, 3])
# This case is legal, but requires flipping the axes to negative indexing, so that the largest axis is
# -1, followed by -2, etc. This introduces a maximum number of axes.
if pos_gaps and positive_axes[0] != 0:
raise ValueError(
"Too many holes in axes list. There can be at most one hole in the axes list, "
"including implict holes resulting from omitting the 0 or -1 axis. In this case, "
"there is an explicit internal hole as well as an implicit left hole."
)

elif positive_axes[0] == 0 and not pos_gaps:
# Case 1: Only right implicit hole. No ambiguities.
n_before = positive_axes[-1] + 1
n_after = 0
min_axes = n_before + n_after
max_axes = None

elif pos_gaps:
# Case 2: Explicit hole in the positives, plus right implicit hole.
split = pos_gaps[0] + 1
n_before = split
n_after = len(positive_axes) - split
min_axes = n_before + n_after

# Close the right implicit hole
max_axes = positive_axes[-1] + 1

else:
# Case 3: Left and right implicit holes, but the right can be closed by flipping to negative axes and
# adding a maximum number of axes.
# Compute min_axes and max_axes under Case 1 of the negative_only scenario, with a max_axes constraint.
max_axes = positive_axes[-1] + 1
n_before = 0
n_after = len(positive_axes)
min_axes = n_before + n_after

if negative_only:
# The same four cases are considered when all axes are negative, but ordering is reversed.
# 0. There are two implicit holes (e.g. [-4, -2])
# This case is always illegal, as there is no interpretation that would result in having only one hole
# in the axis list.
# 1. There is only an implicit left hole (e.g. [-2, -1])
# This case is legal, and requires no special interpretation. It corresponds to '* i j' in einops
# 2. There is an explicit internal hole (e.g. [-3, -1])
# This case is legal, but requires interpreting the smallest axis as 0, which introduces a maximum number
# of axes. It corresponds to '* i j' in einops, and requires at least one input to have 3 dimensions, and
# no input to have more than 3 dimensions.
# 3. The axes end at an index less than -1, but have no internal holes (e.g. [-4, -3]). Flip to positive
# axes, adding a maximum number of axes. Interpret the smallest axis as 0 to resolve ambiguity.
if neg_gaps and negative_axes[-1] != -1:
raise ValueError(
"Too many holes in axes list. There can be at most one hole in the axes list, "
"including implict holes resulting from omitting the 0 or -1 axis. In this case, "
"there is an explicit internal hole as well as an implicit right hole."
)
elif negative_axes[-1] == -1 and not neg_gaps:
# Case 1: No ambiguities, only left implicit hole.
n_before = 0
n_after = len(negative_axes)
min_axes = n_before + n_after
max_axes = None
elif neg_gaps:
# Case 2: Explicit hole in the negatives, plus left implicit hole.
split = neg_gaps[0] + 1
n_before = split
n_after = len(negative_axes) - split
min_axes = n_before + n_after

# Close the left implicit hole
max_axes = abs(min(negative_axes))
else:
# Case 3: Left and right implicit holes, but the left can be closed by flipping to positive axes and
# adding a maximum number of axes.
max_axes = abs(negative_axes[0])
n_before = negative_axes[-1] + max_axes + 1
n_after = 0
min_axes = n_before + n_after

return n_before, n_after, min_axes, max_axes

def validate_inputs(self, tensors: list[TensorLike]):
tensors = [ptb.as_tensor_variable(t) for t in tensors]
_, _, min_axes, max_axes = self._analyze_axes_list()

if min([t.ndim for t in tensors]) < min_axes:
raise ValueError(
f"All input tensors to {self.op_name} must have at least {min_axes} dimensions, but the minimum "
f"number of dimensions found was {min([t.ndim for t in tensors])}."
)

max_ndim = max([t.ndim for t in tensors])
if (
max_axes is not None
and max_ndim > max_axes
and not any(t.ndim == max_axes for t in tensors)
):
raise ValueError(
f"All input tensors to {self.op_name} must have at most {max_axes} dimensions, but the maximum "
f"number of dimensions found was {max_ndim}."
)

def infer_shape(self, tensors: list[TensorLike]) -> tuple[int | None, ...]:
tensors = [ptb.as_tensor_variable(t) for t in tensors]
n_axes_before, n_axes_after, _, _ = self._analyze_axes_list()

def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None:
unique_shapes = {s for s in shapes if s is not None}
if not unique_shapes:
return None
if len(unique_shapes) > 1:
raise ValueError(
f"Input tensors to Pack op have incompatible sizes on dimension {axis} : {shapes}"
)
return unique_shapes.pop()

shapes_to_pack = [
t.type.shape[n_axes_before : t.ndim - n_axes_after] for t in tensors
]
packed_shape = (
None
if any(
shape is None
for packed_shape in shapes_to_pack
for shape in packed_shape
)
else int(sum(np.prod(shapes) for shapes in shapes_to_pack))
)
prefix_shapes = [
_coalesce_dim([t.type.shape[i] for t in tensors], i)
for i in range(n_axes_before)
]
suffix_shapes = [
_coalesce_dim(
[t.type.shape[t.ndim - n_axes_after + i] for t in tensors],
n_axes_before + i,
)
for i in range(n_axes_after)
]

return (*prefix_shapes, packed_shape, *suffix_shapes)


class Pack(OpFromGraph):
"Wrapper for the Pack Op"


def pack(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't tell how this function works for the docstrings. What happens when I pass inputs with different dimensions, and single/list of axes?

*tensors: TensorVariable, axes: int | Sequence[int] | None = None
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
"""
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
Copy link
Member

@ricardoV94 ricardoV94 Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
Concatenate a list of tensors along a subset of consecutive raveled dimensions.


Parameters
----------
tensors: TensorVariable
Tensors to be packed into a single vector.
axes: int or sequence of int, optional
Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
and joined.

Returns
-------
flat_tensor: TensorVariable
A new symbolic variable representing the concatenated 1d vector of all tensor inputs
packed_shapes: list of tuples of TensorVariable
A list of tuples, where each tuple contains the symbolic shape of the original tensors.
"""
if not tensors:
raise ValueError("Cannot pack an empty list of tensors.")

tensors = [ptb.as_tensor(tensor) for tensor in tensors]

pack_helper = PackHelper(axes=axes)

reshaped_tensors = []
tmp_shapes = []

n_axes_before, n_axes_after, _, _ = pack_helper._analyze_axes_list()
pack_helper.validate_inputs(tensors)
output_shape = pack_helper.infer_shape(tensors)

for i, tensor in enumerate(tensors):
shape = tensor.shape
ndim = tensor.ndim
axis_after_packed_axes = ndim - n_axes_after
tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes])
reshaped_tensors.append(
tensor.reshape(
(*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])
)
)

packed_output_tensor = specify_shape(
ptb.join(n_axes_before, *reshaped_tensors), output_shape
)
packed_output_shapes = [
ptb.as_tensor_variable(packed_shape).astype("int64")
for i, packed_shape in enumerate(tmp_shapes)
]

pack_op = Pack(
inputs=tensors,
outputs=[packed_output_tensor, *packed_output_shapes],
name="Pack{axes=" + str(axes) + "}",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why give name instead of just defining the __str__ of the Pack?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I looked at the docs for OpFromGraph and saw there was a name field I could pass

)

outputs = pack_op(*tensors)
return outputs[0], outputs[1:]


def unpack(
flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int]]
) -> tuple[TensorVariable, ...]:
"""
Unpack a flat tensor into its original shapes based on the provided packed shapes.

Parameters
----------
flat_tensor: TensorVariable
A 1D tensor that contains the concatenated values of the original tensors.
packed_shapes: list of tuples of TensorVariable
A list of tuples, where each tuple contains the symbolic shape of the original tensors.

Returns
-------
unpacked_tensors: tuple of TensorVariable
A tuple containing the unpacked tensors with their original shapes.
"""
if not packed_shapes:
raise ValueError("Cannot unpack an empty list of shapes.")

n_splits = len(packed_shapes)
split_size = [prod(shape).astype(int) for shape in packed_shapes]
unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits)

return tuple(
[x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)]
)


__all__ = [
"bartlett",
"bincount",
Expand All @@ -2027,10 +2368,12 @@ def concat_with_broadcast(tensor_list, axis=0):
"geomspace",
"linspace",
"logspace",
"pack",
"ravel_multi_index",
"repeat",
"searchsorted",
"squeeze",
"unique",
"unpack",
"unravel_index",
]
Loading