Commit 750bdb5

hariprasadravi (Hariprasad Ravishankar) authored
[LINALG] Fix: Incorrect linalg lowering for aten.convolution_transpose with negative effective padding (#4369)
### **The Bug**

The `torch-to-linalg` lowering for `aten.convolution` (with `transposed=true`) incorrectly handles cases where the effective padding is negative. The logic for this is contained in `createTransposedInputPadding`. The original implementation had two critical flaws:

- **Incorrect Math**: The logic block for negative padding (`if (anyDimensionPaddingIsNegative)`) attempted to "pre-crop" the input tensor before un-striding. The math used to calculate these slice offsets and sizes was incorrect, resulting in `tensor.extract_slice` operations with out-of-bounds offsets and negative sizes, causing the compiler to fail.
- **Failed "Mixed-Mode" Logic**: The code was built on an "all-or-nothing" assumption. It failed to handle "mixed-mode" padding, where one spatial dimension required padding (positive offset) while another required cropping (negative offset). It would enter the negative padding path and apply cropping logic to all dimensions, leading to out-of-bounds errors when it tried to crop a dimension that should have been padded.

### **The Fix**

This patch refactors the logic into two clean, robust paths:

**All-Padding Path (`else` block):**

- Trigger: All spatial dimensions have an effective padding offset >= 0.
- Action: Retains the original, efficient "fast path." It uses a single `tensor.insert_slice` to perform both un-striding (with strides) and padding (with offsets) in one operation.

**Safe Path (`if (anyDimensionPaddingIsNegative)` block):**

- Trigger: At least one spatial dimension has a negative effective padding offset.
- Action: This path is now a unified, robust 3-step process that correctly handles both all-crop and mixed-mode scenarios:
  1. Create "Super-Tensor": It computes a `maxSizes` tensor, which is the "union" of the padded and un-strided sizes (i.e., `max(innerSize, outerSize)` for each dimension).
  2. Pad & Un-stride: It performs a single `tensor.insert_slice` of the original input into this `maxSizes` tensor. This one operation correctly applies all positive padding (via `insertSliceOffsets`) and un-striding (via `strideIndexValues`).
  3. Crop: It performs a final `tensor.extract_slice` to crop the `maxSizes` tensor down to the final `outerSizes`. This correctly applies all negative padding (via `extractSliceOffsets`).

This new logic resolves all known failure cases and is validated by the new TransposedConv{1,2,3}dNegativePadding test cases, which specifically target this functionality.

---------

Co-authored-by: Hariprasad Ravishankar <[email protected]>
Co-authored-by: Hariprasad Ravishankar <[email protected]>
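The three-step safe path can be illustrated on a single spatial dimension. The following is a minimal pure-Python sketch, not the actual lowering: the helper name `transposed_input_padding_1d` is hypothetical, lists stand in for tensors, and `output_padding` is assumed to be 0 to keep the arithmetic short.

```python
def transposed_input_padding_1d(x, stride, offset, pad=0.0):
    """Un-stride, pad, and crop a 1-D signal (hypothetical illustration).

    offset is the effective padding, (kernel - 1) * dilation - padding.
    A negative offset means "crop" instead of "pad".
    Assumes output_padding == 0.
    """
    inner = (len(x) - 1) * stride + 1   # size after un-striding
    outer = inner + 2 * offset          # final size after pad/crop

    # Step 1: allocate the "super-tensor" of size max(inner, outer).
    buf = [pad] * max(inner, outer)

    # Step 2: a single strided insert applies un-striding and any
    # positive padding (the insert offset is 0 when cropping).
    insert_at = max(offset, 0)
    for i, value in enumerate(x):
        buf[insert_at + i * stride] = value

    # Step 3: a single slice applies any negative padding
    # (the extract offset is 0 when padding).
    crop_at = max(-offset, 0)
    return buf[crop_at:crop_at + outer]


padded = transposed_input_padding_1d([1.0, 2.0, 3.0], stride=2, offset=2)
cropped = transposed_input_padding_1d([1.0, 2.0, 3.0], stride=2, offset=-1)
```

With a positive offset the buffer already has the outer size and step 3 is a no-op; with a negative offset the buffer has the inner size and step 3 trims `-offset` elements from each end. Because each dimension chooses its own insert and extract offsets, mixed-mode inputs need no special handling.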
1 parent 8d563af commit 750bdb5

4 files changed: +225 -36 lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 78 additions & 27 deletions
@@ -1538,6 +1538,25 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 };
 } // namespace

+/*
+ * Calculates the dimensions and offsets needed to emulate a Transposed
+ * Convolution (like PyTorch's ConvTranspose2d) using a standard
+ * Forward Convolution.
+ *
+ * This involves creating a new tensor by:
+ * 1. Calculating `innerSizes`: The input size after dilation by `stride`.
+ * innerSize[i] = (inDim[i] - 1) * stride[i] + 1
+ *
+ * 2. Calculating `outerSizes`: The final padded tensor size.
+ * offset[i] = (weightDim[i] - 1) * dilation[i] - padding[i]
+ * outerSize[i] = innerSize[i] + (2 * offset[i]) + outputPadding[i]
+ *
+ * If `offset[i]` is negative, this is treated as *cropping* the
+ * `innerSizes` tensor. This function calculates the
+ * `insertSliceOffsets` (padding) and `extractSliceOffsets` (cropping)
+ * to correctly place the (potentially cropped) inner tensor within the
+ * new outer tensor.
+ */
 Value ConvertAtenConvolutionOp::createTransposedInputPadding(
     Value inBatch, Value inChannels, SmallVector<Value> &inDims,
     SmallVector<Value> &weightDims, SmallVector<Value> &paddingIntValues,
@@ -1551,33 +1570,34 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
   SmallVector<Value> insertSliceOffsets{c0, c0};

   SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
-  SmallVector<Value> sliceSizes{inputSizes[0], inputSizes[1]};
-
-  // For the case in which the padding dimension value is negative,
-  // we will need to shrink the dimension. Note in the PyTorch
-  // ConvTranspose2d operator documentation that the padding is
-  // defined by dilation * (kernel_size - 1) - padding. If the
-  // resulting padding is negative, PyTorch will extract elements
-  // from both sides of the dimension.
+
   SmallVector<Value> extractSliceOffsets{c0, c0};
   bool anyDimensionPaddingIsNegative = false;

   Value c2 = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(2));

   for (size_t i = 0; i < numSpatialDims; i++) {
+    // Calculate inner size: (input_size - 1) * stride + 1
     Value innerSize = rewriter.createOrFold<arith::SubIOp>(loc, inDims[i], c1);
     innerSize = rewriter.createOrFold<arith::MulIOp>(
         loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
     innerSize = rewriter.createOrFold<arith::AddIOp>(loc, innerSize, c1);
+    innerSizes.push_back(innerSize);

     Value offset = rewriter.createOrFold<arith::SubIOp>(loc, weightDims[i], c1);
     offset = rewriter.createOrFold<arith::MulIOp>(
         loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
     offset = rewriter.createOrFold<arith::SubIOp>(
         loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));

+    // We need to crop or pad from two sides - top&bottom or left&right.
+    // Therefore multiply by 2.
     Value outerSize = rewriter.createOrFold<arith::MulIOp>(loc, offset, c2);
+
+    // Crop or pad based on the sign of offset
     outerSize = rewriter.createOrFold<arith::AddIOp>(loc, outerSize, innerSize);
+
+    // Add optional padding values
     outerSize = rewriter.createOrFold<arith::AddIOp>(
         loc, outerSize,
         castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
@@ -1587,45 +1607,76 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
       // Make the negative value positive by multiplying by -1.
       anyDimensionPaddingIsNegative = true;
       auto offsetType = offset.getType();
-      auto negOneConst = rewriter.createOrFold<arith::ConstantOp>(
-          loc, offsetType, rewriter.getIntegerAttr(offsetType, -1));
+      auto negOneConst = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getIntegerAttr(offsetType, -1));
       auto posOffset =
           rewriter.createOrFold<arith::MulIOp>(loc, offset, negOneConst);

-      // Compute the reduced dimension size due to negative padding.
-      auto sizeReduction =
-          rewriter.createOrFold<arith::MulIOp>(loc, posOffset, c2);
-      sliceSizes.push_back(rewriter.createOrFold<arith::SubIOp>(
-          loc, inputSizes[i + 2], sizeReduction));
-
       extractSliceOffsets.push_back(posOffset);
       insertSliceOffsets.push_back(c0);
     } else {
-      sliceSizes.push_back(inputSizes[i + 2]);
       extractSliceOffsets.push_back(c0);
       insertSliceOffsets.push_back(offset);
     }
   }
-  Value initTensor = createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

   // Insert input into allocated tensor
   SmallVector<Value> strideIndexValues{c1, c1};
   for (auto stride : strideIntValues)
     strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));

-  auto insertSliceOpInput = input;
   if (anyDimensionPaddingIsNegative) {
-    insertSliceOpInput = tensor::ExtractSliceOp::create(
+
+    // Some dimensions may need padding and some dimensions need cropping
+
+    // 1. Allocate a maxSizes buffer (max of inner and outer for each dim)
+    // 2. Insert the input into maxSizes buffer at appropriate offsets (if
+    //    insertSliceOffsets is positive, pad; 0 no padding) and stride
+    // 3. Extract the final outerSizes from maxSizes buffer
+
+    // Create the "max size" tensor to accommodate both padding and cropping
+    SmallVector<Value> maxSizes{inBatch, inChannels};
+    for (size_t i = 0; i < numSpatialDims; ++i) {
+      Value innerDim = innerSizes[i + 2];
+      Value outerDim = outerSizes[i + 2];
+      Value isPadding = rewriter.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ugt, outerDim, innerDim);
+      Value maxDim = rewriter.createOrFold<arith::SelectOp>(loc, isPadding,
+                                                            outerDim, innerDim);
+      maxSizes.push_back(maxDim);
+    }
+
+    Value initMaxTensor =
+        createInitTensor(rewriter, loc, maxSizes, inputDTy, pad);
+
+    // Insert input
+    auto paddedTensor = tensor::InsertSliceOp::create(
         rewriter, loc,
         torch_to_linalg::removeSizeInformation(rewriter, loc, input),
-        extractSliceOffsets, sliceSizes, strideIndexValues);
-  }
+        initMaxTensor, insertSliceOffsets, inputSizes, strideIndexValues);

-  auto paddedInput = tensor::InsertSliceOp::create(
-      rewriter, loc,
-      torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput),
-      initTensor, insertSliceOffsets, sliceSizes, strideIndexValues);
-  return paddedInput;
+    SmallVector<Value> allOnesStrides(inputSizes.size(), c1);
+
+    // Crop. Extract the final tensor from the "max" tensor
+    auto finalTensor = tensor::ExtractSliceOp::create(
+        rewriter, loc,
+        torch_to_linalg::removeSizeInformation(rewriter, loc, paddedTensor),
+        extractSliceOffsets, outerSizes, allOnesStrides);
+
+    return finalTensor;
+
+  } else {
+
+    Value initPaddedTensor =
+        createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
+
+    // Insert the original input into the outer tensor with calculated offsets
+    auto paddedInput = tensor::InsertSliceOp::create(
+        rewriter, loc,
+        torch_to_linalg::removeSizeInformation(rewriter, loc, input),
+        initPaddedTensor, insertSliceOffsets, inputSizes, strideIndexValues);
+    return paddedInput;
+  }
 }

 namespace {
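
The size and offset arithmetic from the header comment of `createTransposedInputPadding` can be checked by hand for static shapes. Below is a minimal Python sketch under that assumption; `transposed_padding_plan` and its dictionary keys are hypothetical names chosen for illustration and are not part of torch-mlir.

```python
def transposed_padding_plan(in_dims, weight_dims, stride, padding,
                            dilation, output_padding):
    """Per spatial dimension, compute the sizes and offsets described in the
    header comment (hypothetical helper, static integer shapes only)."""
    plan = {"inner": [], "outer": [], "max": [],
            "insert_offsets": [], "extract_offsets": []}
    for n, k, s, p, d, op in zip(in_dims, weight_dims, stride, padding,
                                 dilation, output_padding):
        inner = (n - 1) * s + 1            # input size after un-striding
        offset = (k - 1) * d - p           # effective padding, may be < 0
        outer = inner + 2 * offset + op    # final padded or cropped size
        plan["inner"].append(inner)
        plan["outer"].append(outer)
        plan["max"].append(max(inner, outer))
        # A non-negative offset pads via the insert; a negative one crops
        # via the extract. Exactly one of the two is non-zero per dimension.
        plan["insert_offsets"].append(max(offset, 0))
        plan["extract_offsets"].append(max(-offset, 0))
    return plan


# Mixed-mode case: 4x7 input, 3x3 kernel, stride 4, padding [0, 3].
mixed = transposed_padding_plan([4, 7], [3, 3], [4, 4], [0, 3], [1, 1], [0, 0])
```

For this mixed-mode input the sketch yields max sizes 17x25, insert offsets [2, 0], extract offsets [0, 1], and outer sizes 17x23, which agree with the `tensor.empty`, `tensor.insert_slice`, and `tensor.extract_slice` shapes checked in the `tranConv2dNegativeAndPositivePadding` FileCheck test.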

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 0 deletions
@@ -3267,6 +3267,7 @@
     "TraceSignedIntModule_basic",
     "TraceUnsignedIntModule_basic",
     "TraceUnsignedIntModule_empty",
+    "TransposedConv1dNegativePaddingUnitStrideDyn_basic",
     "UniformModule_basic",
     "UniformNoCorrelationModule_basic",
     "UniformStaticShapeModule_basic",
@@ -3961,7 +3962,10 @@
     "TraceModule_empty",
     "TraceUnsignedIntModule_empty",
     "TransposedConv1dNegativePadding_basic",
+    "TransposedConv1dNegativePaddingUnitStrideDyn_basic",
+    "TransposedConv1dNegativePaddingLarge_basic",
     "TransposedConv2dNegativePadding_basic",
+    "TransposedConv2dPositiveAndNegativePadding_basic",
     "TransposedConv3dNegativePadding_basic",
     "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
     "InterpolateDynamicModule_sizes_nearest",
@@ -5039,7 +5043,10 @@
     "TraceUnsignedIntModule_basic",
     "TraceUnsignedIntModule_empty",
     "TransposedConv1dNegativePadding_basic",
+    "TransposedConv1dNegativePaddingUnitStrideDyn_basic",
+    "TransposedConv1dNegativePaddingLarge_basic",
     "TransposedConv2dNegativePadding_basic",
+    "TransposedConv2dPositiveAndNegativePadding_basic",
     "TransposedConv3dNegativePadding_basic",
     "TupleModule_basic",
     "TypeAsDifferentModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 101 additions & 3 deletions
@@ -1988,7 +1988,7 @@ def forward(self, inputVec, weight, bias):
             inputVec,
             weight,
             bias=bias,
-            stride=[1],
+            stride=[4],
             padding=[3],
             dilation=[1],
             transposed=True,
@@ -2002,6 +2002,72 @@ def TransposedConv1dNegativePadding_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 7), tu.rand(1, 2, 3), tu.rand(2))


+class TransposedConv1dNegativePaddingUnitStrideDyn(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([1, 2, 3], torch.float32, True),
+            ([2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[1],
+            padding=[3],
+            dilation=[1],
+            transposed=True,
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(
+    module_factory=lambda: TransposedConv1dNegativePaddingUnitStrideDyn()
+)
+def TransposedConv1dNegativePaddingUnitStrideDyn_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 7), tu.rand(1, 2, 3), tu.rand(2))
+
+
+class TransposedConv1dNegativePaddingLarge(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 17, 5], torch.float32, True),
+            ([17, 6, 3], torch.float32, True),
+            ([6], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[7],
+            padding=[10],
+            dilation=[4],
+            transposed=True,
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: TransposedConv1dNegativePaddingLarge())
+def TransposedConv1dNegativePaddingLarge_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 17, 5), tu.rand(17, 6, 3), tu.rand(6))
+
+
 class TransposedConv2dNegativePadding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -2034,6 +2100,38 @@ def TransposedConv2dNegativePadding_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))


+class TransposedConv2dPositiveAndNegativePadding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 7], torch.float32, True),
+            ([1, 2, 3, 3], torch.float32, True),
+            ([2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[4, 4],
+            padding=[0, 3],
+            dilation=[1, 1],
+            transposed=True,
+            output_padding=[0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: TransposedConv2dPositiveAndNegativePadding())
+def TransposedConv2dPositiveAndNegativePadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))
+
+
 class TransposedConv3dNegativePadding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -2052,9 +2150,9 @@ def forward(self, inputVec, weight, bias):
             inputVec,
             weight,
             bias=bias,
-            stride=[1, 1, 1],
+            stride=[1, 5, 3],
             padding=[2, 1, 3],
-            dilation=[1, 1, 1],
+            dilation=[1, 2, 1],
             transposed=True,
             output_padding=[0, 0, 0],
             groups=1,
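
As a sanity check on the test parameters above, the expected output length of each spatial dimension can be derived from the transposed-convolution size formula given in the PyTorch `ConvTranspose` documentation. The sketch below is only arithmetic; `conv_transpose_out_size` is a hypothetical helper name, and the kernel size is taken from the trailing dimension(s) of each test's weight tensor.

```python
def conv_transpose_out_size(n, k, stride, padding, dilation, output_padding=0):
    """Output length of one spatial dimension of a transposed convolution:
    (n - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1
    """
    return ((n - 1) * stride - 2 * padding
            + dilation * (k - 1) + output_padding + 1)


# TransposedConv1dNegativePadding: length 7, kernel 3, stride 4, padding 3
out_1d = conv_transpose_out_size(7, 3, stride=4, padding=3, dilation=1)

# TransposedConv1dNegativePaddingLarge: length 5, kernel 3, stride 7,
# padding 10, dilation 4
out_large = conv_transpose_out_size(5, 3, stride=7, padding=10, dilation=4)

# TransposedConv2dPositiveAndNegativePadding: 4x7 input, 3x3 kernel,
# stride 4, padding [0, 3]
out_2d = (conv_transpose_out_size(4, 3, stride=4, padding=0, dilation=1),
          conv_transpose_out_size(7, 3, stride=4, padding=3, dilation=1))
```

Each value equals the corresponding `outerSize` minus `dilation * (k - 1)`, the size a unit-stride forward convolution produces over the padded or cropped tensor. The 2-D result, 15x21, matches the output shape used in the `tranConv2dNegativeAndPositivePadding` lit test.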

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 39 additions & 6 deletions
@@ -152,12 +152,17 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>)
 }

 // CHECK-LABEL: func.func @tranConv2dNegativePadding(
-// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32>
-// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
-// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
-// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
+// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x8x7xf32>
+// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x8x7xf32>) -> tensor<1x1x8x7xf32>
+// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x8x7xf32>
+// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 8, 5] [1, 1, 1, 1] : tensor<1x1x8x7xf32> to tensor<1x1x8x5xf32>
+// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
+// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
 func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> !torch.vtensor<[1, 2, 6, 3],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int0 = torch.constant.int 0
   %true = torch.constant.bool true
@@ -174,3 +179,31 @@ func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) ->
   %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1, 1, 4, 7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1, 2, 6, 3],f32>
   return %6 : !torch.vtensor<[1, 2, 6, 3],f32>
 }
+
+// CHECK-LABEL: func.func @tranConv2dNegativeAndPositivePadding(
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>,
+// CHECK-SAME: %[[WEIGHTS_VTENSOR:.*]]: !torch.vtensor<[1,2,3,3],f32>,
+// CHECK-SAME: %[[BIAS_VTENSOR:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
+// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x17x25xf32>
+// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x17x25xf32>) -> tensor<1x1x17x25xf32>
+// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 4, 4] : tensor<1x1x4x7xf32> into tensor<1x1x17x25xf32>
+// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 17, 23] [1, 1, 1, 1] : tensor<1x1x17x25xf32> to tensor<1x1x17x23xf32>
+// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x17x23xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x15x21xf32>) -> tensor<1x2x15x21xf32>
+// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x15x21xf32> -> !torch.vtensor<[1,2,15,21],f32>
+func.func @tranConv2dNegativeAndPositivePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %int0 = torch.constant.int 0
+  %int4 = torch.constant.int 4
+  %true = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %true, %3, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,15,21],f32>
+  return %4 : !torch.vtensor<[1,2,15,21],f32>
+}
