Skip to content

Commit d0a84c4

Browse files
committed
2 parents c5bd2d2 + fe0e68f commit d0a84c4

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

Multigather.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ date: 2025-02-19
1111

1212
> *"One operator to gather them all, to bring them together, and in the darkness bind them."*
1313
14-
ML libraries have a confusing zoo of various gather/scatter operators, and it always takes me a few minutes to recall the brain-bending differences between every `gather*` variant out there, even just in ONNX ([Gather](https://onnx.ai/onnx/operators/onnx__Gather.html), [GatherElements](https://onnx.ai/onnx/operators/onnx__GatherElements.html), [GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)) let alone all the other ML libraries. Many are woefully underdocumentated on behavior too (e.g. [TOSA gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop) and [StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md)).
14+
ML libraries have a confusing mix of various gather/scatter operators, and it always takes me a few minutes to recall the brain-bending differences between every `gather*` variant out there, even just in ONNX ([Gather](https://onnx.ai/onnx/operators/onnx__Gather.html), [GatherElements](https://onnx.ai/onnx/operators/onnx__GatherElements.html), [GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)) let alone other ML libraries. Many are underdocumented too (e.g. [TOSA gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop) and [StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md)). Is there a more fundamental expression of a gathering operation that is more generic while also being simpler to document and implement?
1515

16-
It always bothered me after implementing `DML_OPERATOR_GATHER`, `DML_OPERATOR_GATHER_ELEMENTS`, and `DML_OPERATOR_GATHER_ND` (for the corresponding ONNX `Gather`, `GatherElements`, and `GatherND` operators) that there wasn't a more elegant DML operator to encompass them all at the *API level*, because at the GPU implementation level, every operator used the *same shader* after normalizing the tensor ranks/strides to be rank-compatible and broadcastable (which made the implementation much simpler and reusable). So surely there was a more general API form too hiding behind those differences, after some massaging:
16+
It always bothered me after implementing `DML_OPERATOR_GATHER`, `DML_OPERATOR_GATHER_ELEMENTS`, and `DML_OPERATOR_GATHER_ND` (for the corresponding ONNX operators) that there wasn't a more elegant DML operator to encompass them all at the *API level*, because at the GPU implementation level, every operator used the *same shader* after normalizing the tensor ranks/strides to be rank-compatible and broadcastable (which made the implementation much simpler and reusable). So after some massaging...
1717
- (1) set input and indices tensor ranks consistently, padding with 1's where needed
1818
- (2) pass `axes` explicitly (like how `reduce*` and `resample` take axes) instead of letting them be partially *inferred* from shapes
1919
- (3) use existing broadcasting definitions like those from elementwise operators
2020

21-
With those normalizations, you don't need to re-remember the divergences between each of them, nor need hacks like an extra `batch_dims` parameter.
21+
...you have one operator that can implement each, and you don't need to re-remember the divergences between each of them, nor need hacks like an extra `batch_dims` parameter.
2222

23-
## Equivalence Classes
23+
## Operator Equivalence Classes
2424

2525
Gather operators can be grouped so:
2626

@@ -30,7 +30,7 @@ Gather operators can be grouped so:
3030
| Single axis element gather (1D input absolute indices) | [PyTorch take](https://pytorch.org/docs/stable/generated/torch.take.html) | Same as above, but the input tensor is flattened to 1D first, meaning all indices are linearly unique to each input element.<br/><sup>`input.shape = [indexable dimension as 1D]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = implicitly 0`<br/>`output.shape = [index dimensions...]`<br/>`input.rank == 1 after flattening to 1D`<br/>`output.shape == indices.shape`<br/></sup>
3131
| Single axis block gather | [ONNX Gather](https://onnx.ai/onnx/operators/onnx__Gather.html)<br/>[numpy.take](https://numpy.org/doc/stable/reference/generated/numpy.take.html)<br/>[TensorFlow gather](https://www.tensorflow.org/api_docs/python/tf/gather)<br/>[CoreML gather](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather) | Dimensions are selected at a given axis and any trailing dimensions copy entire blocks to the output (as if those dimensions in indices were broadcast to the input.shape).<br/><sup>`input.shape = [leading dimensions..., input axis dimension, trailing dimensions...]`<br/>`indices.shape = [index dimensions...]`<br/>`axis = 0..(input.rank - 1)`<br/>`output.shape = [leading dimensions..., index dimensions..., trailing dimensions...]`<br/>`output.shape = input.shape[0..axis] ~ indices.shape ~ input.shape[axis+1..input.rank]`</sup>
3232
| Multiple contiguous axes block gather | [ONNX GatherND](https://onnx.ai/onnx/operators/onnx__GatherND.html)<br/>[ONNX gather_nd](https://www.tensorflow.org/api_docs/python/tf/gather_nd)<br/>[CoreML gather_nd](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather_nd) | Axes are indirectly implied by correspondence of input and indices shapes, the batch dimension count, and the size of the last dimension in indices (the lookup coordinate size). Axes start at dimension 0 in the input or after the batch dimension count if nonzero, and the number of indexable input dimensions depends on the coordinate size.<br/><sup>`input.shape = [batch dimensions..., indexable dimensions..., trailing dimensions...]`<br/>`indices.shape = [batch dimensions..., index dimensions..., coordinate size]`<br/>`batch dimension count < min(input.rank, indices.rank)`<br/>`output.shape = [batch dimensions..., index dimensions..., trailing dimensions...]`</sup>
33-
| Multiaxis gather | None known, emulatable via reshape + transpose + gatherND | Multiple noncontiguous axes are supported to gather from the input.<br/><sup>`input.shape = [mix of indexable and broadcastable dimensions...]`<br/>`indices.shape = [mix of index and broadcastable dimensions]`<br/>`output.shape = [mix of indexed and broadcasted dimensions]`<br/>`axes = [any unique dimensions < input.rank ...]`<br/>`broadcastShape = broadcast(input.shape, indices.shape)`<br/>`broadcastShape[axes[∀i]] = indices.shape[axes[∀i]]`<br/>`output.shape = broadcastShape`<br/>`input.rank == indices.rank == output.rank`</sup>
33+
| Multiaxis gather | None known, but emulatable via reshape + transpose + gatherND | Multiple noncontiguous axes are supported to gather from the input.<br/><sup>`input.shape = [mix of indexable and broadcastable dimensions...]`<br/>`indices.shape = [mix of index and broadcastable dimensions]`<br/>`output.shape = [mix of indexed and broadcasted dimensions]`<br/>`axes = [any unique dimensions < input.rank ...]`<br/>`broadcastShape = broadcast(input.shape, indices.shape)`<br/>`broadcastShape[axes[∀i]] = indices.shape[axes[∀i]]`<br/>`output.shape = broadcastShape`<br/>`input.rank == indices.rank == output.rank`</sup>
3434
| Indeterminate from documentation 🤷‍♂️ | [TOSA linalg gather](https://mlir.llvm.org/docs/Dialects/TOSA/#tosagather-mlirtosagatherop)<br/>[TOSA tensor gather](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorgather-tensorgatherop)<br/>[StableHLO gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md) | TOSA's gather is probably equivalent to one of the above, but the docs lack insight. StableHLO's gather looks quite complex, like some hybrid slice/gather chimera 😯 - it's out of scope.
3535

3636
They have the following properties:
@@ -39,7 +39,7 @@ They have the following properties:
3939
|----------------------------------------------------------------------|----------------|----------------|------------------|------------------|
4040
| Multiple axes |||||
4141
| Non-contiguous axes (like N and C in NHWC layout) |||||
42-
| Custom coordinate ordering (like [x,y] or [y,x]) |||||
42+
| Custom coordinate ordering (like \[x,y\] or \[y,x\]) |||||
4343
| Supports input < indices broadcasting before axes |||||
4444
| Supports indices < input broadcasting before axes¹ |||||
4545
| Supports indices < input broadcasting after axes |||||
@@ -51,8 +51,6 @@ They have the following properties:
5151

5252
## Multigather Operator API
5353

54-
This function implements the above:
55-
5654
```javascript
5755
partial interface MLGraphBuilder
5856
{

0 commit comments

Comments
 (0)