@@ -1743,30 +1743,32 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
17431743
17441744 // Stride, padding, dilation for the backward conv. We only support constant
17451745 // lists here, consistent with forward convolution lowering.
1746- SmallVector<int64_t > strideInts;
1746+ SmallVector<Value> paddingIntValues;
1747+ SmallVector<int64_t > strideInts, dilationInts, outputPaddingInts;
1748+
17471749 if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
17481750 return rewriter.notifyMatchFailure (op,
17491751 " only support constant int strides" );
1750- SmallVector< int64_t > dilationInts;
1752+
17511753 if (!matchPattern (op.getDilation (),
17521754 m_TorchListOfConstantInts (dilationInts)))
17531755 return rewriter.notifyMatchFailure (op,
17541756 " only support constant int dilations" );
1755- SmallVector<Value> paddingIntValues, outputPaddingIntValues;
1757+
1758+ if (!matchPattern (op.getOutputPadding (),
1759+ m_TorchListOfConstantInts (outputPaddingInts)))
1760+ return rewriter.notifyMatchFailure (op,
1761+ " only support constant int output paddings" );
1762+ if (!llvm::all_of (outputPaddingInts,
1763+ [](int64_t outPad) { return outPad == 0 ; }))
1764+ return rewriter.notifyMatchFailure (
1765+ op, " unimplemented: only output padding of 0 supported." );
1766+
17561767 if (!getListConstructElements (op.getPadding (), paddingIntValues))
17571768 return rewriter.notifyMatchFailure (
17581769 op, " only support padding from a list construct" );
17591770 paddingIntValues = getTypeConvertedValues (rewriter, loc, getTypeConverter (),
17601771 paddingIntValues);
1761- if (!getListConstructElements (op.getOutputPadding (), outputPaddingIntValues))
1762- return rewriter.notifyMatchFailure (
1763- op, " only support output padding from a list construct" );
1764- outputPaddingIntValues = getTypeConvertedValues (rewriter, loc, getTypeConverter (),
1765- outputPaddingIntValues);
1766- // if (!llvm::all_of(outputPaddingIntValues,
1767- // [](int64_t outPad) { return outPad == 0; }))
1768- // return rewriter.notifyMatchFailure(
1769- // op, "unimplemented: only output padding of 0 supported.");
17701772
17711773 // The expandGroups lambda function below is used to expand the group
17721774 // dimension for weights and input, output tensors.
0 commit comments