-
Notifications
You must be signed in to change notification settings - Fork 148
Implement pack/unpack helpers #1578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
9568a83
2e22d34
79d9662
58c0286
5788333
ed60651
0b86851
20ab4e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,12 +1,14 @@ | ||||||
| import warnings | ||||||
| from collections.abc import Collection, Iterable | ||||||
| from collections.abc import Collection, Iterable, Sequence | ||||||
| from itertools import pairwise | ||||||
| from textwrap import dedent | ||||||
|
|
||||||
| import numpy as np | ||||||
| from numpy.lib.array_utils import normalize_axis_index | ||||||
|
|
||||||
| import pytensor | ||||||
| import pytensor.scalar.basic as ps | ||||||
| from pytensor.compile.builders import OpFromGraph | ||||||
| from pytensor.gradient import ( | ||||||
| DisconnectedType, | ||||||
| _float_zeros_like, | ||||||
|
|
@@ -25,7 +27,7 @@ | |||||
| from pytensor.scalar import upcast | ||||||
| from pytensor.tensor import TensorLike, as_tensor_variable | ||||||
| from pytensor.tensor import basic as ptb | ||||||
| from pytensor.tensor.basic import alloc, join, second | ||||||
| from pytensor.tensor.basic import alloc, join, second, split | ||||||
| from pytensor.tensor.exceptions import NotScalarConstantError | ||||||
| from pytensor.tensor.math import abs as pt_abs | ||||||
| from pytensor.tensor.math import all as pt_all | ||||||
|
|
@@ -43,7 +45,7 @@ | |||||
| ) | ||||||
| from pytensor.tensor.math import max as pt_max | ||||||
| from pytensor.tensor.math import sum as pt_sum | ||||||
| from pytensor.tensor.shape import Shape_i | ||||||
| from pytensor.tensor.shape import Shape_i, specify_shape | ||||||
| from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor | ||||||
| from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes | ||||||
| from pytensor.tensor.utils import normalize_reduce_axis | ||||||
|
|
@@ -2011,6 +2013,345 @@ def concat_with_broadcast(tensor_list, axis=0): | |||||
| return join(axis, *bcast_tensor_inputs) | ||||||
|
|
||||||
|
|
||||||
| class PackHelper: | ||||||
| def __init__(self, axes: int | Sequence[int] | None): | ||||||
| self.axes = tuple(axes) if isinstance(axes, list) else axes | ||||||
| self.op_name = "Pack{axes=" + str(self.axes) + "}" | ||||||
|
|
||||||
| def _analyze_axes_list(self) -> tuple[int, int, int, int | None]: | ||||||
| """ | ||||||
| Analyze the provided axes list to determine how many axes are before and after the interval to be raveled, as | ||||||
| well as the minimum and maximum number of axes that the inputs can have. | ||||||
|
|
||||||
| The rules are: | ||||||
| - Axes must be strictly increasing in both the positive and negative parts of the list. | ||||||
| - Negative axes must come after positive axes. | ||||||
| - There can be at most one "hole" in the axes list, which can be either an implicit hole on an endpoint | ||||||
| (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [1, -1]). | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| n_axes_before: int | ||||||
| The number of axes before the interval to be raveled. | ||||||
| n_axes_after: int | ||||||
| The number of axes after the interval to be raveled. | ||||||
| min_axes: int | ||||||
| The minimum number of axes that the inputs must have. | ||||||
| max_axes: int or None | ||||||
| The maximum number of axes that the inputs can have, or None if there is no strict maximum. A maximum is | ||||||
| only introduced when it would resolve ambiguities in the interpretation of the axes list. For example, | ||||||
| [2, 3] can be either interpreted as having two ravel intervals [:2] and [4:], which is illegal, | ||||||
| unless 3 is interpreted as -1, which is only possible if all inputs have exactly 4 axes. Likewise, | ||||||
| [-3, -1] can be interpreted as having two ravel intervals [:-3], [-3:], unless -3 is interpreted as 0, | ||||||
| which is only possible if all inputs have exactly 3 axes. | ||||||
| """ | ||||||
| axes = self.axes | ||||||
| if axes is None: | ||||||
| return 0, 0, 0, None | ||||||
|
|
||||||
| if isinstance(axes, int): | ||||||
| axes = [axes] | ||||||
|
|
||||||
| if len(set(axes)) != len(axes): | ||||||
| raise ValueError("axes must have no duplicates") | ||||||
| if axes is not None and len(axes) == 0: | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No Ops should be supported in general. Makes writing code easier because you don't have to think about edge case. Empty axes are supported in most Ops that allow variable number of axes (like sum)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What should
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tell me what axes does, and I might know |
||||||
| raise ValueError("axes=[] is ambiguous; use None to ravel all") | ||||||
|
|
||||||
| first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes)) | ||||||
| positive_axes = list(axes[:first_negative_idx]) | ||||||
| negative_axes = list(axes[first_negative_idx:]) | ||||||
|
|
||||||
| if not all(a < 0 for a in negative_axes): | ||||||
| raise ValueError("Negative axes must come after positive") | ||||||
|
|
||||||
| def strictly_increasing(s): | ||||||
| return all(b > a for a, b in pairwise(s)) | ||||||
|
|
||||||
| if (positive_axes and not strictly_increasing(positive_axes)) or ( | ||||||
| negative_axes and not strictly_increasing(negative_axes) | ||||||
| ): | ||||||
| raise ValueError("Axes must be strictly increasing") | ||||||
|
|
||||||
| def find_gaps(s): | ||||||
| return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1] | ||||||
|
|
||||||
| pos_gaps = find_gaps(positive_axes) | ||||||
| neg_gaps = find_gaps(negative_axes) | ||||||
| positive_only = positive_axes and not negative_axes | ||||||
| negative_only = negative_axes and not positive_axes | ||||||
| mixed_case = positive_axes and negative_axes | ||||||
|
|
||||||
| max_axes: int | None = None | ||||||
|
|
||||||
| n_explicit_holes = len(pos_gaps) + len(neg_gaps) | ||||||
| if n_explicit_holes > 1: | ||||||
| raise ValueError( | ||||||
| "Too many holes in axes list. There can be at most one hole in the axes list, " | ||||||
| "including implict holes resulting from omitting the 0 or -1 axis." | ||||||
| ) | ||||||
|
|
||||||
| if mixed_case: | ||||||
| if pos_gaps or neg_gaps: | ||||||
| raise ValueError( | ||||||
| "Too many holes in axes list. There can be at most one hole in the axes list, " | ||||||
| "including implict holes resulting from omitting the 0 or -1 axis. Because both " | ||||||
| "positive and negative axes are present, there is always assume to be an explit hole " | ||||||
| "between them." | ||||||
| ) | ||||||
| n_before = len(positive_axes) | ||||||
| n_after = len(negative_axes) | ||||||
| min_axes = n_before + n_after | ||||||
|
|
||||||
| if positive_only: | ||||||
| # There are four cases to consider when all axes are positive: | ||||||
| # 0. There are two implicit gaps (0 is not present) and an explicit gap (e.g. [2, 4]) | ||||||
| # This case is always illegal, as there is no interpretation that would result in having | ||||||
| # 1. There is only an implicit right hole (e.g. [0, 1]) | ||||||
| # This case is legal, and requires no special interpretation. It corresponds to 'i j *' in einops | ||||||
| # 2. There is an explicit internal hole (e.g. [0, 2]) | ||||||
| # This case is legal, but requires interpreting the last axis as -1, which introduces a maximum number | ||||||
| # of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and | ||||||
| # no input to have more than 3 dimensions. | ||||||
| # 2. The axes start at an index greater than 0, but have no internal holes (e.g. [2, 3]) | ||||||
| # This case is legal, but requires flipping the axes to negative indexing, so that the largest axis is | ||||||
| # -1, followed by -2, etc. This introduces a maximum number of axes. | ||||||
| if pos_gaps and positive_axes[0] != 0: | ||||||
| raise ValueError( | ||||||
| "Too many holes in axes list. There can be at most one hole in the axes list, " | ||||||
| "including implict holes resulting from omitting the 0 or -1 axis. In this case, " | ||||||
| "there is an explicit internal hole as well as an implicit left hole." | ||||||
| ) | ||||||
|
|
||||||
| elif positive_axes[0] == 0 and not pos_gaps: | ||||||
| # Case 1: Only right implicit hole. No ambiguities. | ||||||
| n_before = positive_axes[-1] + 1 | ||||||
| n_after = 0 | ||||||
| min_axes = n_before + n_after | ||||||
| max_axes = None | ||||||
|
|
||||||
| elif pos_gaps: | ||||||
| # Case 2: Explicit hole in the positives, plus right implicit hole. | ||||||
| split = pos_gaps[0] + 1 | ||||||
| n_before = split | ||||||
| n_after = len(positive_axes) - split | ||||||
| min_axes = n_before + n_after | ||||||
|
|
||||||
| # Close the right implicit hole | ||||||
| max_axes = positive_axes[-1] + 1 | ||||||
|
|
||||||
| else: | ||||||
| # Case 3: Left and right implicit holes, but the right can be closed by flipping to negative axes and | ||||||
| # adding a maximum number of axes. | ||||||
| # Compute min_axes and max_axes under Case 1 of the negative_only scenario, with a max_axes constraint. | ||||||
| max_axes = positive_axes[-1] + 1 | ||||||
| n_before = 0 | ||||||
| n_after = len(positive_axes) | ||||||
| min_axes = n_before + n_after | ||||||
|
|
||||||
| if negative_only: | ||||||
| # The same four cases are considered when all axes are negative, but ordering is reversed. | ||||||
| # 0. There are two implicit holes (e.g. [-4, -2]) | ||||||
| # This case is always illegal, as there is no interpretation that would result in having only one hole | ||||||
| # in the axis list. | ||||||
| # 1. There is only an implicit left hole (e.g. [-2, -1]) | ||||||
| # This case is legal, and requires no special interpretation. It corresponds to '* i j' in einops | ||||||
| # 2. There is an explicit internal hole (e.g. [-3, -1]) | ||||||
| # This case is legal, but requires interpreting the smallest axis as 0, which introduces a maximum number | ||||||
| # of axes. It corresponds to '* i j' in einops, and requires at least one input to have 3 dimensions, and | ||||||
| # no input to have more than 3 dimensions. | ||||||
| # 3. The axes end at an index less than -1, but have no internal holes (e.g. [-4, -3]). Flip to positive | ||||||
| # axes, adding a maximum number of axes. Interpret the smallest axis as 0 to resolve ambiguity. | ||||||
| if neg_gaps and negative_axes[-1] != -1: | ||||||
| raise ValueError( | ||||||
| "Too many holes in axes list. There can be at most one hole in the axes list, " | ||||||
| "including implict holes resulting from omitting the 0 or -1 axis. In this case, " | ||||||
| "there is an explicit internal hole as well as an implicit right hole." | ||||||
| ) | ||||||
| elif negative_axes[-1] == -1 and not neg_gaps: | ||||||
| # Case 1: No ambiguities, only left implicit hole. | ||||||
| n_before = 0 | ||||||
| n_after = len(negative_axes) | ||||||
| min_axes = n_before + n_after | ||||||
| max_axes = None | ||||||
| elif neg_gaps: | ||||||
| # Case 2: Explicit hole in the negatives, plus left implicit hole. | ||||||
| split = neg_gaps[0] + 1 | ||||||
| n_before = split | ||||||
| n_after = len(negative_axes) - split | ||||||
| min_axes = n_before + n_after | ||||||
|
|
||||||
| # Close the left implicit hole | ||||||
| max_axes = abs(min(negative_axes)) | ||||||
| else: | ||||||
| # Case 3: Left and right implicit holes, but the left can be closed by flipping to positive axes and | ||||||
| # adding a maximum number of axes. | ||||||
| max_axes = abs(negative_axes[0]) | ||||||
| n_before = negative_axes[-1] + max_axes + 1 | ||||||
| n_after = 0 | ||||||
| min_axes = n_before + n_after | ||||||
|
|
||||||
| return n_before, n_after, min_axes, max_axes | ||||||
|
|
||||||
| def validate_inputs(self, tensors: list[TensorLike]): | ||||||
| tensors = [ptb.as_tensor_variable(t) for t in tensors] | ||||||
| _, _, min_axes, max_axes = self._analyze_axes_list() | ||||||
|
|
||||||
| if min([t.ndim for t in tensors]) < min_axes: | ||||||
| raise ValueError( | ||||||
| f"All input tensors to {self.op_name} must have at least {min_axes} dimensions, but the minimum " | ||||||
| f"number of dimensions found was {min([t.ndim for t in tensors])}." | ||||||
| ) | ||||||
|
|
||||||
| max_ndim = max([t.ndim for t in tensors]) | ||||||
| if ( | ||||||
| max_axes is not None | ||||||
| and max_ndim > max_axes | ||||||
| and not any(t.ndim == max_axes for t in tensors) | ||||||
| ): | ||||||
| raise ValueError( | ||||||
| f"All input tensors to {self.op_name} must have at most {max_axes} dimensions, but the maximum " | ||||||
| f"number of dimensions found was {max_ndim}." | ||||||
| ) | ||||||
|
|
||||||
| def infer_shape(self, tensors: list[TensorLike]) -> tuple[int | None, ...]: | ||||||
| tensors = [ptb.as_tensor_variable(t) for t in tensors] | ||||||
| n_axes_before, n_axes_after, _, _ = self._analyze_axes_list() | ||||||
|
|
||||||
| def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None: | ||||||
| unique_shapes = {s for s in shapes if s is not None} | ||||||
| if not unique_shapes: | ||||||
| return None | ||||||
| if len(unique_shapes) > 1: | ||||||
| raise ValueError( | ||||||
| f"Input tensors to Pack op have incompatible sizes on dimension {axis} : {shapes}" | ||||||
| ) | ||||||
| return unique_shapes.pop() | ||||||
|
|
||||||
| shapes_to_pack = [ | ||||||
| t.type.shape[n_axes_before : t.ndim - n_axes_after] for t in tensors | ||||||
| ] | ||||||
| packed_shape = ( | ||||||
| None | ||||||
| if any( | ||||||
| shape is None | ||||||
| for packed_shape in shapes_to_pack | ||||||
| for shape in packed_shape | ||||||
| ) | ||||||
| else int(sum(np.prod(shapes) for shapes in shapes_to_pack)) | ||||||
| ) | ||||||
| prefix_shapes = [ | ||||||
| _coalesce_dim([t.type.shape[i] for t in tensors], i) | ||||||
| for i in range(n_axes_before) | ||||||
| ] | ||||||
| suffix_shapes = [ | ||||||
| _coalesce_dim( | ||||||
| [t.type.shape[t.ndim - n_axes_after + i] for t in tensors], | ||||||
| n_axes_before + i, | ||||||
| ) | ||||||
| for i in range(n_axes_after) | ||||||
| ] | ||||||
|
|
||||||
| return (*prefix_shapes, packed_shape, *suffix_shapes) | ||||||
|
|
||||||
|
|
||||||
| class Pack(OpFromGraph): | ||||||
| "Wrapper for the Pack Op" | ||||||
|
|
||||||
|
|
||||||
| def pack( | ||||||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't tell how this function works for the docstrings. What happens when I pass inputs with different dimensions, and single/list of axes? |
||||||
| *tensors: TensorVariable, axes: int | Sequence[int] | None = None | ||||||
| ) -> tuple[TensorVariable, list[tuple[TensorVariable]]]: | ||||||
| """ | ||||||
| Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| tensors: TensorVariable | ||||||
| Tensors to be packed into a single vector. | ||||||
| axes: int or sequence of int, optional | ||||||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled | ||||||
| and joined. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| flat_tensor: TensorVariable | ||||||
| A new symbolic variable representing the concatenated 1d vector of all tensor inputs | ||||||
| packed_shapes: list of tuples of TensorVariable | ||||||
| A list of tuples, where each tuple contains the symbolic shape of the original tensors. | ||||||
| """ | ||||||
| if not tensors: | ||||||
| raise ValueError("Cannot pack an empty list of tensors.") | ||||||
|
|
||||||
| tensors = [ptb.as_tensor(tensor) for tensor in tensors] | ||||||
|
|
||||||
| pack_helper = PackHelper(axes=axes) | ||||||
|
|
||||||
| reshaped_tensors = [] | ||||||
| tmp_shapes = [] | ||||||
|
|
||||||
| n_axes_before, n_axes_after, _, _ = pack_helper._analyze_axes_list() | ||||||
| pack_helper.validate_inputs(tensors) | ||||||
| output_shape = pack_helper.infer_shape(tensors) | ||||||
|
|
||||||
| for i, tensor in enumerate(tensors): | ||||||
| shape = tensor.shape | ||||||
| ndim = tensor.ndim | ||||||
| axis_after_packed_axes = ndim - n_axes_after | ||||||
| tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes]) | ||||||
| reshaped_tensors.append( | ||||||
| tensor.reshape( | ||||||
| (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]) | ||||||
| ) | ||||||
| ) | ||||||
|
|
||||||
| packed_output_tensor = specify_shape( | ||||||
| ptb.join(n_axes_before, *reshaped_tensors), output_shape | ||||||
| ) | ||||||
| packed_output_shapes = [ | ||||||
| ptb.as_tensor_variable(packed_shape).astype("int64") | ||||||
| for i, packed_shape in enumerate(tmp_shapes) | ||||||
| ] | ||||||
|
|
||||||
| pack_op = Pack( | ||||||
| inputs=tensors, | ||||||
| outputs=[packed_output_tensor, *packed_output_shapes], | ||||||
| name="Pack{axes=" + str(axes) + "}", | ||||||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why give name instead of just defining the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because I looked at the docs for |
||||||
| ) | ||||||
|
|
||||||
| outputs = pack_op(*tensors) | ||||||
| return outputs[0], outputs[1:] | ||||||
|
|
||||||
|
|
||||||
| def unpack( | ||||||
| flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int]] | ||||||
| ) -> tuple[TensorVariable, ...]: | ||||||
| """ | ||||||
| Unpack a flat tensor into its original shapes based on the provided packed shapes. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| flat_tensor: TensorVariable | ||||||
| A 1D tensor that contains the concatenated values of the original tensors. | ||||||
| packed_shapes: list of tuples of TensorVariable | ||||||
| A list of tuples, where each tuple contains the symbolic shape of the original tensors. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| unpacked_tensors: tuple of TensorVariable | ||||||
| A tuple containing the unpacked tensors with their original shapes. | ||||||
| """ | ||||||
| if not packed_shapes: | ||||||
| raise ValueError("Cannot unpack an empty list of shapes.") | ||||||
|
|
||||||
| n_splits = len(packed_shapes) | ||||||
| split_size = [prod(shape).astype(int) for shape in packed_shapes] | ||||||
| unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits) | ||||||
|
|
||||||
| return tuple( | ||||||
| [x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)] | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| __all__ = [ | ||||||
| "bartlett", | ||||||
| "bincount", | ||||||
|
|
@@ -2027,10 +2368,12 @@ def concat_with_broadcast(tensor_list, axis=0): | |||||
| "geomspace", | ||||||
| "linspace", | ||||||
| "logspace", | ||||||
| "pack", | ||||||
| "ravel_multi_index", | ||||||
| "repeat", | ||||||
| "searchsorted", | ||||||
| "squeeze", | ||||||
| "unique", | ||||||
| "unpack", | ||||||
| "unravel_index", | ||||||
| ] | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pretty gnarly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
technically you asked for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How so?