@@ -6237,7 +6237,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
-    DenseI64ArrayAttr &pad) {
+    DenseI64ArrayAttr &pad, SmallVectorImpl<int64_t> &explicitNHWCPad) {
 
   RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
   if (!inputTy)
@@ -6277,21 +6277,39 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, we materialize an
+    // explicit pad after transposing to NHWC. Track the padding extents and
+    // zero out the TOSA op padding so the divisor matches the full kernel
+    // size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||
          countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+      // Remember the spatial padding so we can emit an NHWC tosa.pad right
+      // after the transpose.
+      explicitNHWCPad.assign(
+          {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
+
+      auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+        if (ShapedType::isDynamic(dim))
+          return ShapedType::kDynamic;
+        return dim + before + after;
+      };
+
+      // Update the logical input type used for shape computations to include
+      // the extra zeros supplied by the explicit pad.
+      SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
+                                       inputTy.getShape().end());
+      // Height sits at index inputRank - 2 and width at inputRank - 1 while
+      // the tensor is still in NCHW order; the NHWC transpose happens later.
+      paddedShape[inputRank - 2] =
+          addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
+      paddedShape[inputRank - 1] =
+          addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
+      inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
 
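Why zeroing the op-level padding fixes the divisor: as the replaced comment notes, TOSA's avg_pool2d cannot express count_include_pad=true, since its divisor counts only non-pad elements. Once the zeros are materialized in the tensor itself, every window is full, so dividing by the kernel size reproduces count_include_pad=true exactly. A standalone toy sketch (plain C++, not torch-mlir code) contrasting the two conventions on the row {2, 4} with kernel 2 and pad 1:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> in = {2.0f, 4.0f};
  const int kernel = 2, pad = 1;

  // Materialize the zero padding explicitly, as the lowering now does with
  // an NHWC tosa.pad ahead of the pooling op.
  std::vector<float> padded(in.size() + 2 * pad, 0.0f);
  for (size_t i = 0; i < in.size(); ++i)
    padded[i + pad] = in[i];

  for (size_t start = 0; start + kernel <= padded.size(); ++start) {
    float sum = 0.0f;
    int nonPad = 0;
    for (int k = 0; k < kernel; ++k) {
      sum += padded[start + k];
      if (start + k >= (size_t)pad && start + k < pad + in.size())
        ++nonPad; // element originated from the unpadded input
    }
    // Full-kernel divisor == count_include_pad=true; the non-pad divisor is
    // what a padded TOSA avg_pool computes (count_include_pad=false).
    std::printf("window %zu: include_pad=%.2f exclude_pad=%.2f\n", start,
                sum / kernel, nonPad ? sum / nonPad : 0.0f);
  }
}

The three windows print 1.00/2.00, 3.00/3.00, and 2.00/4.00, matching torch.nn.AvgPool1d with count_include_pad set to true and false respectively.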
@@ -6314,6 +6332,18 @@ static LogicalResult getOutputTypeAndPoolingParameters(
   return success();
 }
 
+template <typename AtenOpT, typename tosaOp>
+static LogicalResult getOutputTypeAndPoolingParameters(
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
+    SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
+    DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
+    DenseI64ArrayAttr &pad) {
+  SmallVector<int64_t, 4> ignoredExplicitPad;
+  return getOutputTypeAndPoolingParameters<AtenOpT, tosaOp>(
+      op, rewriter, inputXchw, dilationArray, outputTy, kernel, stride, pad,
+      ignoredExplicitPad);
+}
+
 class ConvertAtenMaxPool2dOp
     : public ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp> {
 public:
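The pad-only overload added above keeps the widened signature source-compatible: callers that predate the explicit-pad support, such as the max-pool patterns, compile unchanged while the extra out-parameter is supplied and discarded internally. A hypothetical call site (mirroring the avg-pool call below; the real max-pool call sites may differ):

  SmallVector<int64_t, 2> dilationArray{1, 1};
  // Old-style call through the thin overload; explicitNHWCPad never surfaces.
  if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                               tosa::MaxPool2dOp>(
          op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
    return rewriter.notifyMatchFailure(
        op, "invalid pooling parameters or input type");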
@@ -6435,15 +6465,23 @@ class ConvertAtenAvgPool2dOp
     }
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
+            explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, self);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, self);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }
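tosa::emitExplicitZeroPadNHWC is called here but not defined in this hunk. Judging from the call sites, it takes the NHWC-transposed value plus {top, bottom, left, right} extents and emits a zero-filled tosa.pad. A minimal sketch of such a helper, assuming the newer upstream MLIR API where the pad amounts are passed as a shape operand built with tosa::getTosaConstShape (older toolchains encode the padding, and any pad constant, differently):

static Value emitExplicitZeroPadNHWC(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     Operation *op, Value inputNHWC,
                                     ArrayRef<int64_t> padTBLR) {
  (void)op; // mirrored from the call site; useful only for diagnostics
  auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
  // N and C stay unpadded; H takes {top, bottom} and W takes {left, right}.
  SmallVector<int64_t, 8> padPairs = {0,          0,          // N
                                      padTBLR[0], padTBLR[1], // H
                                      padTBLR[2], padTBLR[3], // W
                                      0,          0};         // C
  SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
                                      inputTy.getShape().end());
  for (int dim = 0; dim < 4; ++dim)
    if (!ShapedType::isDynamic(resultShape[dim]))
      resultShape[dim] += padPairs[2 * dim] + padPairs[2 * dim + 1];
  Value padding = tosa::getTosaConstShape(rewriter, loc, padPairs);
  // Some TOSA versions also require an explicit zero pad_const operand;
  // elided in this sketch.
  return rewriter.create<tosa::PadOp>(
      loc, RankedTensorType::get(resultShape, inputTy.getElementType()),
      inputNHWC, padding);
}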
@@ -6486,16 +6524,23 @@ class ConvertAtenAvgPool1dOp
             .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
             op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
-            pad)))
+            pad, explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }