Skip to content

Commit bc657db

Browse files
[TOSA] Add legalization for avg_pool with count_include_pad=True (#4273)
Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations lacked support for pooling with count_include_pad=True. This patch introduces that support. --------- Signed-off-by: Vitalii Shutov <[email protected]>
1 parent e61b71a commit bc657db

File tree

5 files changed

+189
-59
lines changed

5 files changed

+189
-59
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
107107
Type inputElemTy, Type outputElemTy,
108108
ArrayRef<int64_t> weightShape);
109109

110+
// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
111+
// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
112+
// {top, bottom, left, right}. Returns the padded tensor value.
113+
Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
114+
Operation *op, Value inputNHWC,
115+
ArrayRef<int64_t> padExtents);
116+
110117
} // namespace tosa
111118
} // namespace mlir
112119

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6237,7 +6237,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
62376237
AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
62386238
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
62396239
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
6240-
DenseI64ArrayAttr &pad) {
6240+
DenseI64ArrayAttr &pad, SmallVectorImpl<int64_t> &explicitNHWCPad) {
62416241

62426242
RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
62436243
if (!inputTy)
@@ -6277,21 +6277,39 @@ static LogicalResult getOutputTypeAndPoolingParameters(
62776277

62786278
if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
62796279
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
6280-
// Currently, we can not represent `count_include_pad` with the existing
6281-
// TOSA AvgPool2d specification. Without the below check, we produce silent
6282-
// wrong answer (SWA) when the `count_include_pad` value is `true.`
6283-
//
6284-
// Note: We need to check for `count_include_pad` only when the `padding`
6285-
// value is non-zero.
6280+
// When count_include_pad=true with non-zero padding, we will materialize an
6281+
// explicit pad after transposing to NHWC. Track the padding extents and
6282+
// zero out the TOSA op padding so the divisor matches the full kernel size.
62866283
bool countIncludePad;
62876284
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
62886285
(!matchPattern(op.getCountIncludePad(),
62896286
m_TorchConstantBool(&countIncludePad)) ||
62906287

62916288
countIncludePad)) {
6292-
return rewriter.notifyMatchFailure(
6293-
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
6294-
"`count_include_pad` value should be `False`.");
6289+
// Remember the spatial padding so we can emit an NHWC tosa.pad right
6290+
// after the transpose.
6291+
explicitNHWCPad.assign(
6292+
{paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
6293+
6294+
auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
6295+
if (ShapedType::isDynamic(dim))
6296+
return ShapedType::kDynamic;
6297+
return dim + before + after;
6298+
};
6299+
6300+
// Update the logical input type used for shape computations to include
6301+
// the extra zeros supplied by the explicit pad.
6302+
SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
6303+
inputTy.getShape().end());
6304+
// Height stored at rank-2 and width at rank-1 while the tensor is still
6305+
// in NCHW order; the NHWC transpose happens later.
6306+
paddedShape[inputRank - 2] =
6307+
addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
6308+
paddedShape[inputRank - 1] =
6309+
addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
6310+
inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());
6311+
6312+
paddingInts.assign(/*Count=*/2, /*Value=*/0);
62956313
}
62966314
}
62976315

@@ -6314,6 +6332,18 @@ static LogicalResult getOutputTypeAndPoolingParameters(
63146332
return success();
63156333
}
63166334

6335+
template <typename AtenOpT, typename tosaOp>
6336+
static LogicalResult getOutputTypeAndPoolingParameters(
6337+
AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
6338+
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
6339+
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
6340+
DenseI64ArrayAttr &pad) {
6341+
SmallVector<int64_t, 4> ignoredExplicitPad;
6342+
return getOutputTypeAndPoolingParameters<AtenOpT, tosaOp>(
6343+
op, rewriter, inputXchw, dilationArray, outputTy, kernel, stride, pad,
6344+
ignoredExplicitPad);
6345+
}
6346+
63176347
class ConvertAtenMaxPool2dOp
63186348
: public ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp> {
63196349
public:
@@ -6435,15 +6465,23 @@ class ConvertAtenAvgPool2dOp
64356465
}
64366466

64376467
SmallVector<int64_t, 2> dilationArray{1, 1};
6468+
SmallVector<int64_t, 4> explicitNHWCPad;
64386469
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
64396470
tosa::AvgPool2dOp>(
6440-
op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
6471+
op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
6472+
explicitNHWCPad)))
64416473
return rewriter.notifyMatchFailure(
64426474
op, "invalid pooling parameters or input type");
64436475

6444-
// Transpose to xHWC
6445-
input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
6446-
transposePoolingInputToHwc(op, rewriter, self);
6476+
Value transposed =
6477+
ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
6478+
transposePoolingInputToHwc(op, rewriter, self);
6479+
6480+
if (!explicitNHWCPad.empty())
6481+
transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
6482+
transposed, explicitNHWCPad);
6483+
6484+
input = transposed;
64476485

64486486
return success();
64496487
}
@@ -6486,16 +6524,23 @@ class ConvertAtenAvgPool1dOp
64866524
.getResult();
64876525

64886526
SmallVector<int64_t, 2> dilationArray{1, 1};
6527+
SmallVector<int64_t, 4> explicitNHWCPad;
64896528
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
64906529
tosa::AvgPool2dOp>(
64916530
op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
6492-
pad)))
6531+
pad, explicitNHWCPad)))
64936532
return rewriter.notifyMatchFailure(
64946533
op, "invalid pooling parameters or input type");
64956534

6496-
// Transpose to xHWC
6497-
input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
6498-
transposePoolingInputToHwc(op, rewriter, reshapedSelf);
6535+
Value transposed =
6536+
ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
6537+
transposePoolingInputToHwc(op, rewriter, reshapedSelf);
6538+
6539+
if (!explicitNHWCPad.empty())
6540+
transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
6541+
transposed, explicitNHWCPad);
6542+
6543+
input = transposed;
64996544

65006545
return success();
65016546
}

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,5 +624,45 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
624624
}
625625
}
626626

627+
// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that a
// later avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
// {top, bottom, left, right}. Returns the padded tensor value, or the input
// unchanged when all extents are zero or the input is not a ranked tensor.
Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
                              Operation *op, Value inputNHWC,
                              ArrayRef<int64_t> padExtents) {
  assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");

  // Nothing to pad; return the input untouched without creating any ops.
  if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
    return inputNHWC;

  // Check the input type *before* materializing any constants. The previous
  // version built the const-shape op first and could leave it dead in the IR
  // when bailing out on an unranked input.
  auto inputTy = dyn_cast<RankedTensorType>(inputNHWC.getType());
  if (!inputTy)
    return inputNHWC;
  // Dims 1 and 2 are treated as H and W below, so the tensor must be rank 4.
  assert(inputTy.getRank() == 4 && "expected a rank-4 NHWC tensor");

  // tosa.pad takes {before, after} pairs per dimension in NHWC order; the
  // batch (N) and channel (C) dimensions receive no padding.
  SmallVector<int64_t, 8> nhwcPadding = {
      0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
  Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);

  // Compute the result shape, propagating dynamic dimensions unchanged.
  auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
    if (ShapedType::isDynamic(dim))
      return ShapedType::kDynamic;
    return dim + before + after;
  };
  SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
                                      inputTy.getShape().end());
  resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
  resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);

  auto resultTy = RankedTensorType::get(resultShape, inputTy.getElementType());

  // Zero-valued pad constant matching the element type (float vs. integer).
  Type elemTy = inputTy.getElementType();
  Value padConst;
  if (isa<mlir::FloatType>(elemTy)) {
    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
  } else {
    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
  }

  return tosa::PadOp::create(rewriter, loc, resultTy, inputNHWC, nhwcPadShape,
                             padConst)
      .getResult();
}
666+
627667
} // namespace tosa
628668
} // namespace mlir

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3533,7 +3533,6 @@
35333533
"AtenSymConstrainRangeForSize_basic",
35343534
"AtenSymConstrainRange_basic",
35353535
"Aten_AssertScalar_basic",
3536-
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
35373536
"ScatterAddDynamicModule_basic",
35383537
"UniformModule_basic",
35393538
"UniformStaticShapeModule_basic",
@@ -3655,21 +3654,14 @@
36553654
"AtenTopKModule_basic",
36563655
"AtenTopKSmallestModule_basic",
36573656
"Aten_EmbeddingBagExample_basic",
3658-
"AvgPool1dFloatModule_basic",
36593657
"AvgPool1dIntModule_basic",
36603658
"AvgPool1dStaticModule_basic",
3661-
"AvgPool2dCeilModeTrueModule_basic",
36623659
"AvgPool1dNoPadCeilPadNotIncluded_basic",
36633660
"AvgPool1dPadCeilPadNotIncluded_basic",
3664-
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
3665-
"AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
3666-
"AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
36673661
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
36683662
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
36693663
"AvgPool2dDivisorOverrideModule_basic",
3670-
"AvgPool2dFloatModule_basic",
36713664
"AvgPool2dIntModule_basic",
3672-
"AvgPool2dStaticModule_basic",
36733665
"BernoulliFloatModule_basic",
36743666
"BernoulliPModule_basic",
36753667
"BernoulliTensorModule_basic",

0 commit comments

Comments
 (0)