Commit 2fcd68a

[TOSA] Fix empty-dim reductions
Teach the TorchToTosa reducer that an explicit empty dim list means "all dims" and cast the result back to the requested dtype. Add MLIR and e2e regression cases and update the XFAIL sets.

Change-Id: Ibd1be38d219ad5c1986eb4a641efbb9ff0cb6a55
Parent: 8b77de9
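For reference, a minimal eager-PyTorch sketch of the behaviour this lowering is aligned with. The shapes and value range mirror the new e2e tests below; the snippet itself is illustrative and not part of the change:

    import torch

    a = torch.randint(-16, 16, (3, 4, 5), dtype=torch.int8)

    # An explicit empty dim list reduces over every dimension,
    # yielding a 0-d (scalar) tensor rather than being a no-op.
    full = torch.sum(a, dim=[])
    assert full.shape == torch.Size([])

    # A requested dtype casts the result to that element type
    # (int32 here, from an int8 input).
    as_i32 = torch.sum(a, dim=[], dtype=torch.int32)
    assert as_i32.dtype == torch.int32 and as_i32.shape == torch.Size([])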

5 files changed: +87 additions, -3 deletions

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 5 additions & 0 deletions
@@ -1089,6 +1089,11 @@ class ConvertAtenMultipleDimsReductionOp
     for (int64_t i = 0; i < inputRank; i++)
       reduceDims.push_back(i);
   }
+  // PyTorch treats an explicit empty list the same as "reduce all dims".
+  if (reduceDims.empty()) {
+    for (int64_t i = 0; i < inputRank; i++)
+      reduceDims.push_back(i);
+  }
 
   int64_t N = reduceDims.size();
   for (unsigned i = 0; i < N; i++) {

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 11 additions & 1 deletion
@@ -782,13 +782,23 @@ std::optional<Value> convertReduceOpCommon(
 
     // Optionally squeeze out the reduced axes.
     if (!keep_dims) {
+      auto squeezedType =
+          RankedTensorType::get(output_shape, reduce_element_type);
       auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
-          rewriter, op->getLoc(), output_type, val,
+          rewriter, op->getLoc(), squeezedType, val,
           tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape));
       val = reshape_op.getResult();
     }
   }
 
+  // Ensure the result element type matches the expected output type.
+  if (val.getType() != output_type) {
+    auto casted = tosa::tosaCastTensorToType(rewriter, val, output_type);
+    if (!casted)
+      return std::nullopt;
+    val = casted.value();
+  }
+
   return val;
 }
 

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 2 deletions
@@ -3434,6 +3434,8 @@
     "ElementwiseClampMinModule_bfloat16",
     "ElementwiseClampModule_bfloat16",
     "ElementwiseReluModule_bfloat16",
+    # torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this '<class 'torch.TensorType'>'
+    "ReduceSumEmptyDimListInt8ToInt32Module_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3846,7 +3848,6 @@
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
-    "MeanDimEmptyDimModule_basic",
     "MlGroupNormManualModule_basic",
     "MlGroupNormModule_basic",
     "MlLayerNormManualModule_basic",
@@ -3901,7 +3902,6 @@
     "ReduceL3NormKeepDimComplexModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceSumDimIntListEmptyDimModule_basic",
     "RollModule_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 46 additions & 0 deletions
@@ -58,6 +58,52 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReduceSumEmptyDimListInt8ToInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[], dtype=torch.int32)
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8ToInt32Module())
+def ReduceSumEmptyDimListInt8ToInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
+class ReduceSumEmptyDimListInt8Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[])
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8Module())
+def ReduceSumEmptyDimListInt8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
 class ReduceSumElementTypeBoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 23 additions & 0 deletions
@@ -311,6 +311,29 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !
 
 // -----
 
+// CHECK-LABEL:   func.func @test_reduce_sum_empty_dims$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.none
+// CHECK:           %[[VAL_3:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK:           %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK:           %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32>
+// CHECK:           %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32>
+// CHECK:           %[[VAL_7:.*]] = tosa.const_shape
+// CHECK:           %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor<f32>
+// CHECK:           %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<f32> -> !torch.vtensor<[],f32>
+// CHECK:           return %[[VAL_9]] : !torch.vtensor<[],f32>
+// CHECK:         }
+func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %empty = torch.prim.ListConstruct : () -> !torch.list<int>
+  %0 = torch.aten.sum.dim_IntList %arg0, %empty, %false, %none : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @test_linalg_vector_norm$basic(
 // CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
 // CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>
