Skip to content

Commit 64e237c

Browse files
committed
added check for out paddings
1 parent f1d59b4 commit 64e237c

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,30 +1743,32 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
17431743

17441744
// Stride, padding, dilation for the backward conv. We only support constant
17451745
// lists here, consistent with forward convolution lowering.
1746-
SmallVector<int64_t> strideInts;
1746+
SmallVector<Value> paddingIntValues;
1747+
SmallVector<int64_t> strideInts, dilationInts, outputPaddingInts;
1748+
17471749
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
17481750
return rewriter.notifyMatchFailure(op,
17491751
"only support constant int strides");
1750-
SmallVector<int64_t> dilationInts;
1752+
17511753
if (!matchPattern(op.getDilation(),
17521754
m_TorchListOfConstantInts(dilationInts)))
17531755
return rewriter.notifyMatchFailure(op,
17541756
"only support constant int dilations");
1755-
SmallVector<Value> paddingIntValues, outputPaddingIntValues;
1757+
1758+
if (!matchPattern(op.getOutputPadding(),
1759+
m_TorchListOfConstantInts(outputPaddingInts)))
1760+
return rewriter.notifyMatchFailure(op,
1761+
"only support constant int output paddings");
1762+
if (!llvm::all_of(outputPaddingInts,
1763+
[](int64_t outPad) { return outPad == 0; }))
1764+
return rewriter.notifyMatchFailure(
1765+
op, "unimplemented: only output padding of 0 supported.");
1766+
17561767
if (!getListConstructElements(op.getPadding(), paddingIntValues))
17571768
return rewriter.notifyMatchFailure(
17581769
op, "only support padding from a list construct");
17591770
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
17601771
paddingIntValues);
1761-
if (!getListConstructElements(op.getOutputPadding(), outputPaddingIntValues))
1762-
return rewriter.notifyMatchFailure(
1763-
op, "only support output padding from a list construct");
1764-
outputPaddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
1765-
outputPaddingIntValues);
1766-
//if (!llvm::all_of(outputPaddingIntValues,
1767-
// [](int64_t outPad) { return outPad == 0; }))
1768-
// return rewriter.notifyMatchFailure(
1769-
// op, "unimplemented: only output padding of 0 supported.");
17701772

17711773
// The expandGroups lambda function below is used to expand the group
17721774
// dimension for weights and input, output tensors.

0 commit comments

Comments
 (0)