@@ -5766,19 +5766,69 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }
 
-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim
+      // to be fully divisible by stride. We have to modify the after pad
+      // and/or the input in order to achieve that.
+      // Note: The dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value that
+      // TOSA supports for MaxPool2d (AvgPool2d doesn't have dilation, so the
+      // value defaults to 1).
       int64_t dimSize =
           inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, we meet
+      // TOSA's divisibility requirement by removing the unused after pad and
+      // slicing off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          // In cases where we have to do 2 slice operations (one for height
+          // and one for width), we need to use the new sliced shape before
+          // doing the second slice, not the original inputShape. Therefore,
+          // the shape needs to be retrieved again here.
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
       int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, we meet
+      // TOSA's divisibility requirement by removing the unused after pad, or
+      // by adding more after pad when the remainder exceeds the after pad.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
       return outputDim;
     }
   }
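To sanity-check the floor- and ceil-mode adjustments above, here is a minimal standalone sketch of the same scalar arithmetic (a hypothetical helper, not part of the patch; it assumes dilation == 1 and elides the tosa::SliceOp that the real pattern emits when the remainder exceeds the after pad):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of getOutputDim's scalar arithmetic with dilation == 1.
static int64_t adjustedOutputDim(int64_t inputDim, int64_t kernelDim,
                                 int64_t stride, int64_t padBefore,
                                 int64_t &padAfter, bool ceilMode) {
  int64_t dimSize = inputDim + padBefore + padAfter - (kernelDim - 1) - 1;
  int64_t remainder = dimSize % stride;
  if (!ceilMode && remainder != 0) {
    // Floor mode: drop the unused after pad; when the remainder exceeds the
    // after pad, the real pattern additionally slices off unused input rows.
    dimSize -= remainder > padAfter ? padAfter : remainder;
    padAfter = remainder > padAfter ? 0 : padAfter - remainder;
  }
  int64_t outputDim = dimSize / stride + 1;
  if (ceilMode && remainder != 0) {
    padAfter = remainder < padAfter ? padAfter - remainder
                                    : padAfter + (stride - remainder);
    if (outputDim * stride < inputDim + padBefore)
      outputDim++;
  }
  return outputDim;
}

int main() {
  // Floor mode: kernel 3, stride 2, pad 1 on a length-6 dim.
  // dimSize = 6 + 1 + 1 - 3 = 5, remainder = 1 <= padAfter, so the after pad
  // shrinks to 0; 6 + 1 + 0 - 3 = 4 is divisible by the stride, and the
  // output dim (3) matches PyTorch's floor formula.
  int64_t padAfter = 1;
  assert(adjustedOutputDim(6, 3, 2, /*padBefore=*/1, padAfter, false) == 3);
  assert(padAfter == 0);

  // Ceil mode: kernel 2, stride 2, no pad on a length-5 dim.
  // dimSize = 5 - 2 = 3, remainder = 1 > padAfter, so the after pad grows to
  // stride - remainder = 1; 5 + 0 + 1 - 2 = 4 is divisible by the stride, and
  // the output dim bumps from 2 to 3, matching PyTorch's ceil formula.
  padAfter = 0;
  assert(adjustedOutputDim(5, 2, 2, /*padBefore=*/0, padAfter, true) == 3);
  assert(padAfter == 1);
  return 0;
}
```

In both cases the pad adjustment keeps PyTorch's output size while making inputDim + padBefore + padAfter - kernelDim divisible by the stride, which is exactly what TOSA demands.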
@@ -6016,25 +6066,24 @@ class ConvertAtenAdaptivePoolingOp
 
 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();
 
+  // PyTorch uses xCHW, so the height dim index is rank - 2 and the width dim
+  // index is rank - 1.
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
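Note the index change in the two calls above: padArray is now read as four independent entries rather than two symmetric ones. A small hypothetical illustration of the assumed layout:

```cpp
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

// Assumed layout, matching the calls above: {padTop, padBottom, padLeft,
// padRight}. The height dim reads entries 0/1 and the width dim entries 2/3;
// getOutputDim takes the "after" entries (1 and 3) by reference so it can
// shrink or grow them to satisfy TOSA's divisibility requirement.
llvm::SmallVector<int64_t, 4> padArray{/*top=*/1, /*bottom=*/1,
                                       /*left=*/2, /*right=*/2};
```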
@@ -6065,7 +6114,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6138,10 +6187,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
@@ -6157,6 +6204,7 @@ class ConvertAtenMaxPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6169,14 +6217,13 @@ class ConvertAtenMaxPool2dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6210,11 +6257,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6231,14 +6282,14 @@ class ConvertAtenMaxPool1dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
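The rank-4 reshape at the top of this hunk is the usual 1-d-as-2-d trick; a minimal sketch with hypothetical shapes:

```cpp
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

int main() {
  // A 1-d pool over [N, C, L] is expressed as a 2-d pool over [N, C, L, 1];
  // the trailing width dim of size 1 is a no-op for the kernel, so the
  // existing 2-d lowering can be reused unchanged.
  llvm::SmallVector<int64_t> selfShape{2, 8, 16}; // N, C, L
  llvm::SmallVector<int64_t> rank4Shape(selfShape);
  rank4Shape.push_back(1); // -> {2, 8, 16, 1}
  assert(rank4Shape.size() == 4 && rank4Shape.back() == 1);
  return 0;
}
```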
@@ -6254,6 +6305,7 @@ class ConvertAtenAvgPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
 
     // Currently, we can not represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
@@ -6267,14 +6319,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6308,23 +6359,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }