Skip to content

Commit 478acf7

Browse files
[torchlib] Fix unbind.int if num_outputs=1 (#2684)
This fixes the issue of ``` return [op.Squeeze(out, [dim]) for out in outputs] ^^^^^^^ TypeError: 'SymbolicTensor' object is not iterable ``` when trying to export LSTM modules in `torch`. This also already appeared in torch issues in pytorch/pytorch#126339 The core seems to be the changes in #2597. To my understanding the split returns a single `SymbolicTensor` instead of a sequence when `dim=1`. The fix implemented here is the casting of the return type to a list. I struggled with writing a test that reproduces this nicely in here, any guidance on that would be welcome. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent 971f9bb commit 478acf7

File tree

1 file changed

+6
-1
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+6
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8957,7 +8957,12 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
89578957
if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
89588958
# We can create a definitive split op if the input shape is static
89598959
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
8960-
outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8960+
num_outputs = self.shape[dim]
8961+
if num_outputs != 1:
8962+
outputs = op.Split(self, axis=dim, num_outputs=num_outputs)
8963+
else:
8964+
outputs = [self]
8965+
89618966
return [op.Squeeze(out, [dim]) for out in outputs]
89628967

89638968
return op.SplitToSequence(self, axis=dim, keepdims=False)

0 commit comments

Comments
 (0)