Skip to content

Commit c5bd2d2

Browse files
committed
Multigather minor polish
1 parent ed4fdcb commit c5bd2d2

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

Multigather.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,26 @@ Gather operators can be grouped so:
2727
| Category | Library Names | Notes |
2828
|----------------------------------------------------------------------|---------------|--------------|
2929
| Single axis element gather | [ONNX GatherElements](https://onnx.ai/onnx/operators/onnx__GatherElements.html)<br/>[PyTorch gather](https://pytorch.org/docs/stable/generated/torch.gather.html)<br/>[PyTorch take_along_dim](https://pytorch.org/docs/stable/generated/torch.take_along_dim.html)<br/>[numpy.take_along_axis](https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html)<br/>[CoreML gather_along_axis](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather_along_axis) | All tensors have the same rank. All dimensions in input and indices have the same size except the active axis.<br/><sup>`input.shape = [leading dimensions..., input axis dimension, trailing dimensions...]`<br/>`indices.shape = [leading dimensions..., output axis dimension, trailing dimensions...]`<br/>`axis = 0..(input.rank - 1)`<br/>`output.shape = [leading dimensions..., output axis dimension, trailing dimensions...]`<br/>`output.shape[axis] == indices.shape[axis]`<br/>`input.rank == indices.rank == output.rank`</sup>
30-
| Single axis element gather 1D input | [PyTorch take](https://pytorch.org/docs/stable/generated/torch.take.html) | Same as above, but input tensors is flattened to 1D first.<br/><sup>`input.shape = [indexable dimension as 1D]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = implicitly 0`<br/>`output.shape = [index dimensions...]`<br/>`input.rank == 1 after flattening to 1D`<br/>`output.shape == indices.shape`<br/></sup>
30+
| Single axis element gather (1D input absolute indices) | [PyTorch take](https://pytorch.org/docs/stable/generated/torch.take.html) | Same as above, but the input tensor is flattened to 1D first, meaning all indices are linearly unique to each input element.<br/><sup>`input.shape = [indexable dimension as 1D]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = implicitly 0`<br/>`output.shape = [index dimensions...]`<br/>`input.rank == 1 after flattening to 1D`<br/>`output.shape == indices.shape`<br/></sup>
3131
| Single axis block gather | [ONNX Gather](https://onnx.ai/onnx/operators/onnx__Gather.html)<br/>[numpy.take](https://numpy.org/doc/stable/reference/generated/numpy.take.html)<br/>[TensorFlow gather](https://www.tensorflow.org/api_docs/python/tf/gather)<br/>[CoreML gather](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather) | Dimensions are selected at a given axis and any trailing dimensions copy entire blocks to the output (as if those dimensions in indices were broadcast to the input.shape).<br/><sup>`input.shape = [leading dimensions..., input axis dimension, trailing dimensions...]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = 0..(input.rank - 1)`<br/>`output.shape = [leading dimensions..., index dimensions..., trailing dimensions...]`<br/>`output.shape = input.shape[0..axis] ~ indices.shape ~ input.shape[axis+1..input.rank]`</sup>
3232
| Multiple contiguous axes block gather | [ONNX GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)<br/>[ONNX gather_nd](https://www.tensorflow.org/api_docs/python/tf/gather_nd)<br/>[CoreML gather_nd](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather_nd) | Axes are indirectly implied by correspondence of input and indices shapes, the batch dimension count, and the size of the last dimension in indices (the lookup coordinate size). Axes start at dimension 0 in the input or after the batch dimension count if nonzero, and the number of indexable input dimensions depends on the coordinate size.<br/><sup>`input.shape = [batch dimensions..., indexable dimensions..., trailing dimensions...]`<br/>`indices.shape = [batch dimensions..., index dimensions..., coordinate size]`<br/>`batch dimension count < min(input.rank, indices.rank)`<br/>`output.shape = [batch dimensions..., index dimensions..., trailing dimensions...]`</sup>
3333
| Multiaxis gather | None known, emulatable via reshape + transpose + gatherND | Multiple noncontiguous axes are supported to gather from the input.<br/><sup>`input.shape = [mix of indexable and broadcastable dimensions...]`<br/>`indices.shape = [mix of index and broadcastable dimensions]`<br/>`output.shape = [mix of indexed and broadcasted dimensions]`<br/>`axes = [any unique dimensions < input.rank ...]`<br/>`broadcastShape = broadcast(input.shape, indices.shape)`<br/>`broadcastShape[axes[∀i]] = indices.shape[axes[∀i]]`<br/>`output.shape = broadcastShape`<br/>`input.rank == indices.rank == output.rank`</sup>
3434
| Indeterminate from documentation 🤷‍♂️ | [TOSA linalg gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop)<br/>[TOSA tensor gather](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorgather-tensorgatherop)<br/>[StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md) | TOSA's gather is probably equivalent to one of the above, but the docs lack insight. StableHLO's gather looks quite complex, like some hybrid slice/gather chimera 😯 - it's out of scope.
3535

3636
They have the following properties:
3737

38-
| Category | GatherElements | Gather(blocks) | GatherND(blocks) | Gather Axes |
39-
|----------------------------------------------------------------------|----------------|----------------|------------------|-------------|
40-
| Multiple axes |||||
41-
| Non-contiguous axes (like N and C in NHWC layout) |||||
42-
| Custom coordinate ordering (like [x,y] or [y,x]) |||||
43-
| Supports input < indices broadcasting before axes |||||
44-
| Supports indices < input broadcasting before axes¹ |||||
45-
| Supports indices < input broadcasting after axes |||||
46-
| Supports trailing broadcasting (after axes) |||||
47-
| Trivial implementation² |||||
48-
49-
- ¹ Unsure if it's supposed to or not, but ORT 2024-11-26 crashes with a divizion by zero when trying.
38+
| Category | GatherElements | Gather(blocks) | GatherND(blocks) | Gather Multiaxis |
39+
|----------------------------------------------------------------------|----------------|----------------|------------------|------------------|
40+
| Multiple axes |||| |
41+
| Non-contiguous axes (like N and C in NHWC layout) |||| |
42+
| Custom coordinate ordering (like [x,y] or [y,x]) |||| |
43+
| Supports input < indices broadcasting before axes |||| |
44+
| Supports indices < input broadcasting before axes¹ |||| |
45+
| Supports indices < input broadcasting after axes |||| |
46+
| Supports trailing broadcasting (after axes) |||| |
47+
| Trivial implementation² |||| |
48+
49+
- ¹ Unsure if it's supposed to or not, but ORT 2024-11-26 [crashes](https://github.com/microsoft/onnxruntime/issues/23828) with a divizion by zero when trying.
5050
- ² Trivial implementations reduce the chances of bugs.
5151

5252
## Multigather Operator API

0 commit comments

Comments
 (0)