You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: Multigather.md
+6-8Lines changed: 6 additions & 8 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -11,16 +11,16 @@ date: 2025-02-19
11
11
12
12
> *"One operator to gather them all, to bring them together, and in the darkness bind them."*
13
13
14
-
ML libraries have a confusing zoo of various gather/scatter operators, and it always takes me a few minutes to recall the brain-bending differences between every `gather*` variant out there, even just in ONNX ([Gather](https://onnx.ai/onnx/operators/onnx__Gather.html), [GatherElements](https://onnx.ai/onnx/operators/onnx__GatherElements.html), [GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)) let alone all the other ML libraries. Many are woefully underdocumentated on behavior too (e.g. [TOSA gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop) and [StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md)).
14
+
ML libraries have a confusing mix of various gather/scatter operators, and it always takes me a few minutes to recall the brain-bending differences between every `gather*` variant out there, even just in ONNX ([Gather](https://onnx.ai/onnx/operators/onnx__Gather.html), [GatherElements](https://onnx.ai/onnx/operators/onnx__GatherElements.html), [GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)) let alone other ML libraries. Many are underdocumented too (e.g. [TOSA gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop) and [StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md)). Is there a more fundamental expression of a gathering operation that is more generic while also being simpler to document and implement?
15
15
16
-
It always bothered me after implementing `DML_OPERATOR_GATHER`, `DML_OPERATOR_GATHER_ELEMENTS`, and `DML_OPERATOR_GATHER_ND` (for the corresponding ONNX `Gather`, `GatherElements`, and `GatherND`operators) that there wasn't a more elegant DML operator to encompass them all at the *API level*, because at the GPU implementation level, every operator used the *same shader* after normalizing the tensor ranks/strides to be rank-compatible and broadcastable (which made the implementation much simpler and reusable). So surely there was a more general API form too hiding behind those differences, after some massaging:
16
+
It always bothered me after implementing `DML_OPERATOR_GATHER`, `DML_OPERATOR_GATHER_ELEMENTS`, and `DML_OPERATOR_GATHER_ND` (for the corresponding ONNX operators) that there wasn't a more elegant DML operator to encompass them all at the *API level*, because at the GPU implementation level, every operator used the *same shader* after normalizing the tensor ranks/strides to be rank-compatible and broadcastable (which made the implementation much simpler and reusable). So after some massaging...
17
17
- (1) set input and indices tensor ranks consistently, padding with 1's where needed
18
18
- (2) pass `axes` explicitly (like how `reduce*` and `resample` take axes) instead of letting them be partially *inferred* from shapes
19
19
- (3) use existing broadcasting definitions like those from elementwise operators
20
20
21
-
With those normalizations, you don't need to re-remember the divergences between each of them, nor need hacks like an extra `batch_dims` parameter.
21
+
...you have one operator that can implement each, and you don't need to re-remember the divergences between each of them, nor need hacks like an extra `batch_dims` parameter.
22
22
23
-
## Equivalence Classes
23
+
## Operator Equivalence Classes
24
24
25
25
Gather operators can be grouped so:
26
26
@@ -30,7 +30,7 @@ Gather operators can be grouped so:
30
30
| Single axis element gather (1D input absolute indices) | [PyTorch take](https://pytorch.org/docs/stable/generated/torch.take.html) | Same as above, but the input tensor is flattened to 1D first, meaning all indices are linearly unique to each input element.<br/><sup>`input.shape = [indexable dimension as 1D]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = implicitly 0`<br/>`output.shape = [index dimensions...]`<br/>`input.rank == 1 after flattening to 1D`<br/>`output.shape == indices.shape`<br/></sup>
31
31
| Single axis block gather | [ONNX Gather](https://onnx.ai/onnx/operators/onnx__Gather.html)<br/>[numpy.take](https://numpy.org/doc/stable/reference/generated/numpy.take.html)<br/>[TensorFlow gather](https://www.tensorflow.org/api_docs/python/tf/gather)<br/>[CoreML gather](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather) | Dimensions are selected at a given axis and any trailing dimensions copy entire blocks to the output (as if those dimensions in indices were broadcast to the input.shape).<br/><sup>`input.shape = [leading dimensions..., input axis dimension, trailing dimensions...]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = 0..(input.rank - 1)`<br/>`output.shape = [leading dimensions..., index dimensions..., trailing dimensions...]`<br/>`output.shape = input.shape[0..axis] ~ indices.shape ~ input.shape[axis+1..input.rank]`</sup>
32
32
| Multiple contiguous axes block gather | [ONNX GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)<br/>[ONNX gather_nd](https://www.tensorflow.org/api_docs/python/tf/gather_nd)<br/>[CoreML gather_nd](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather_nd) | Axes are indirectly implied by correspondence of input and indices shapes, the batch dimension count, and the size of the last dimension in indices (the lookup coordinate size). Axes start at dimension 0 in the input or after the batch dimension count if nonzero, and the number of indexable input dimensions depends on the coordinate size.<br/><sup>`input.shape = [batch dimensions..., indexable dimensions..., trailing dimensions...]`<br/>`indices.shape = [batch dimensions..., index dimensions..., coordinate size]`<br/>`batch dimension count < min(input.rank, indices.rank)`<br/>`output.shape = [batch dimensions..., index dimensions..., trailing dimensions...]`</sup>
33
-
| Multiaxis gather | None known, emulatable via reshape + transpose + gatherND | Multiple noncontiguous axes are supported to gather from the input.<br/><sup>`input.shape = [mix of indexable and broadcastable dimensions...]`<br/>`indices.shape = [mix of index and broadcastable dimensions]`<br/>`output.shape = [mix of indexed and broadcasted dimensions]`<br/>`axes = [any unique dimensions < input.rank ...]`<br/>`broadcastShape = broadcast(input.shape, indices.shape)`<br/>`broadcastShape[axes[∀i]] = indices.shape[axes[∀i]]`<br/>`output.shape = broadcastShape`<br/>`input.rank == indices.rank == output.rank`</sup>
33
+
| Multiaxis gather | None known, but emulatable via reshape + transpose + gatherND | Multiple noncontiguous axes are supported to gather from the input.<br/><sup>`input.shape = [mix of indexable and broadcastable dimensions...]`<br/>`indices.shape = [mix of index and broadcastable dimensions]`<br/>`output.shape = [mix of indexed and broadcasted dimensions]`<br/>`axes = [any unique dimensions < input.rank ...]`<br/>`broadcastShape = broadcast(input.shape, indices.shape)`<br/>`broadcastShape[axes[∀i]] = indices.shape[axes[∀i]]`<br/>`output.shape = broadcastShape`<br/>`input.rank == indices.rank == output.rank`</sup>
34
34
| Indeterminate from documentation 🤷♂️ | [TOSA linalg gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop)<br/>[TOSA tensor gather](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorgather-tensorgatherop)<br/>[StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md) | TOSA's gather is probably equivalent to one of the above, but the docs lack insight. StableHLO's gather looks quite complex, like some hybrid slice/gather chimera 😯 - it's out of scope.
35
35
36
36
They have the following properties:
@@ -39,7 +39,7 @@ They have the following properties:
0 commit comments