@@ -1327,6 +1327,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 
     SmallVector<int64_t> padding, strides, dilations;
     SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
+    SmallVector<Value> paddingValues;
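+    // Pads are now tracked as !torch.int Values so the SAME_* branches below
+    // can derive them from possibly-dynamic input and weight dims.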
+
     for (unsigned i = 0; i < rank - 2; i++) {
       defaultPadding.push_back(0);
       defaultStrides.push_back(1);
@@ -1360,36 +1362,88 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
     // at the beginning of axis i and xi_end, the number of pixels added at
     // the end of axis i.
     if (autoPad == "NOTSET") {
-      if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
+      if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
         return failure();
-      }
+
+      // Use the explicitly provided padding values.
+      for (int64_t pad : padding)
+        paddingValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(pad)));
     } else if (autoPad == "VALID") {
-      padding = defaultPadding;
+      for (int64_t pad : defaultPadding)
+        paddingValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(pad)));
     } else {
       const bool isSameLower = autoPad == "SAME_LOWER";
       const unsigned spatialRank = rank - 2;
-      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-      padding.resize_for_overwrite(2 * spatialRank);
+      paddingValues.resize_for_overwrite(2 * spatialRank);
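+      // Every slot is assigned in the loop below, so the default-constructed
+      // (null) Values left by resize_for_overwrite are never read.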
+
       for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
-        if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
-            inputShape[dimIdx + 2] == Torch::kUnknownSize)
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "expected weight and input tensor to have static shape");
-        const int64_t dilatedKernelSize =
-            dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
-        int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
-                                strides[dimIdx] -
-                            1) *
-                               strides[dimIdx] +
-                           dilatedKernelSize - inputShape[dimIdx + 2];
-        totalPad = totalPad >= 0 ? totalPad : 0;
-        padding[dimIdx] =
-            isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
-        padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
+        // dilatedKernelSize = dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        Value dilationValue = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(dilations[dimIdx]));
+        Value weightDimSize =
+            Torch::getTensorDimSize(rewriter, weight, dimIdx + 2);
+        Value weightMinusOne = rewriter.create<Torch::AtenSubIntOp>(
+            loc, weightDimSize, cstOne);
+        Value dilationMulWeight = rewriter.create<Torch::AtenMulIntOp>(
+            loc, dilationValue, weightMinusOne);
+        Value dilatedKernelSize = rewriter.create<Torch::AtenAddIntOp>(
+            loc, dilationMulWeight, cstOne);
+
+        // totalPad = (((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
+        //              strides[dimIdx]) - 1) * strides[dimIdx] +
+        //            dilatedKernelSize - inputShape[dimIdx + 2];
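+        // e.g. inputSize = 5, stride = 2, kernel = 2, dilation = 1 gives
+        // dilatedKernelSize = 2, totalPad = ((5 + 1) / 2 - 1) * 2 + 2 - 5 = 1.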
+
+        Value stridesValue = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(strides[dimIdx]));
+        Value inputDimSize =
+            Torch::getTensorDimSize(rewriter, input, dimIdx + 2);
+        Value stridesMinusOne =
+            rewriter.create<Torch::AtenSubIntOp>(loc, stridesValue, cstOne);
+        Value inputStrides = rewriter.create<Torch::AtenAddIntOp>(
+            loc, inputDimSize, stridesMinusOne);
+        inputStrides = rewriter.create<Torch::AtenFloordivIntOp>(
+            loc, inputStrides, stridesValue);
+        inputStrides =
+            rewriter.create<Torch::AtenSubIntOp>(loc, inputStrides, cstOne);
+        inputStrides = rewriter.create<Torch::AtenMulIntOp>(
+            loc, inputStrides, stridesValue);
+        Value strideWithDilation = rewriter.create<Torch::AtenAddIntOp>(
+            loc, inputStrides, dilatedKernelSize);
+        Value totalPad = rewriter.create<Torch::AtenSubIntOp>(
+            loc, strideWithDilation, inputDimSize);
+
+        // totalPad = totalPad > 0 ? totalPad : 0;
+        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(0));
+        totalPad =
+            rewriter.create<Torch::PrimMaxIntOp>(loc, totalPad, cstZero);
+
+        // padding[dimIdx] =
+        //     isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
+        // padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
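+        // e.g. totalPad = 1: SAME_LOWER yields {begin = 1, end = 0},
+        // SAME_UPPER yields {begin = 0, end = 1}.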
+        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(2));
+        if (isSameLower) {
+          Value padPlusOne =
+              rewriter.create<Torch::AtenAddIntOp>(loc, totalPad, cstOne);
+          paddingValues[dimIdx] = rewriter.create<Torch::AtenFloordivIntOp>(
+              loc, padPlusOne, cstTwo);
+        } else {
+          paddingValues[dimIdx] = rewriter.create<Torch::AtenFloordivIntOp>(
+              loc, totalPad, cstTwo);
+        }
+        paddingValues[spatialRank + dimIdx] =
+            rewriter.create<Torch::AtenSubIntOp>(loc, totalPad,
+                                                 paddingValues[dimIdx]);
       }
     }
-    if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+
+    if (paddingValues.size() != rank - 2 &&
+        paddingValues.size() != 2 * (rank - 2)) {
       return rewriter.notifyMatchFailure(
           binder.op, "padding list size does not match the number of axes");
     }
@@ -1398,11 +1452,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         cstOutputPadding;
     Value paddedInput = input;
     Value paddingList;
-    if (padding.size() != 2 * (rank - 2)) {
-      for (int64_t i : padding) {
-        cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(i)));
-      }
+
+    Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+
+    if (paddingValues.size() != 2 * (rank - 2)) {
+      cstPadding = paddingValues;
       paddingList = rewriter.create<Torch::PrimListConstructOp>(
           loc,
           Torch::ListType::get(
@@ -1418,17 +1473,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       // rightmost dim start and end, then next to last, and so on, e.g. {l,
       // r, t, b}.
       bool matchedPads = true;
-      for (unsigned i = 0; i < padding.size() / 2; i++) {
-        if (padding[i] != padding[i + (padding.size() / 2)]) {
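+      // Symmetry can only be proven when both pad amounts fold to constants;
+      // pads that remain dynamic Values take the explicit pad path below.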
+      for (unsigned i = 0; i < paddingValues.size() / 2; i++) {
+        int64_t padLow, padHigh;
+        if (!matchPattern(paddingValues[i],
+                          Torch::m_TorchConstantInt(&padLow)) ||
+            !matchPattern(paddingValues[i + (paddingValues.size() / 2)],
+                          Torch::m_TorchConstantInt(&padHigh)) ||
+            padLow != padHigh) {
           matchedPads = false;
           break;
         }
       }
       if (matchedPads) {
-        for (unsigned i = 0; i < padding.size() / 2; i++) {
-          cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(padding[i])));
-        }
+        for (unsigned i = 0; i < paddingValues.size() / 2; i++)
+          cstPadding.push_back(paddingValues[i]);
         paddingList = rewriter.create<Torch::PrimListConstructOp>(
             loc,
             Torch::ListType::get(
@@ -1437,16 +1495,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       } else {
         SmallVector<Value> padsRearrange;
         SmallVector<Value> inputPaddingList;
-        for (uint32_t i = 0; i < padding.size() / 2; i++) {
-          padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(
-                       padding[padding.size() / 2 - i - 1])));
-          padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-              loc,
-              rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
-          inputPaddingList.emplace_back(
-              rewriter.create<Torch::ConstantIntOp>(
-                  loc, rewriter.getI64IntegerAttr(0)));
+        for (uint32_t i = 0; i < paddingValues.size() / 2; i++) {
+          padsRearrange.emplace_back(
+              paddingValues[paddingValues.size() / 2 - i - 1]);
+          padsRearrange.emplace_back(
+              paddingValues[paddingValues.size() - i - 1]);
+          inputPaddingList.emplace_back(cstZero);
         }
         // The conv op itself will have no padding since the actual padding
         // is performed using the torch.pad preceding it.
@@ -1468,23 +1522,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         Value constantValue;
 
         if (isa<IntegerType>(inputTensorType.getDtype()))
-          constantValue = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(0));
+          constantValue = cstZero;
         if (isa<FloatType>(inputTensorType.getDtype()))
           constantValue = rewriter.create<Torch::ConstantFloatOp>(
               loc, rewriter.getF64FloatAttr(0.0f));
+
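+        // The pad amounts fold into the padded result type only when they
+        // are compile-time constants; otherwise the dim stays dynamic.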
+        auto getPadOutputSizeForInput = [&](int64_t low, int64_t high,
+                                            int64_t inputSize) {
+          int64_t padLow, padHigh;
+          if (inputSize == Torch::kUnknownSize ||
+              !matchPattern(paddingValues[low],
+                            Torch::m_TorchConstantInt(&padLow)) ||
+              !matchPattern(paddingValues[high],
+                            Torch::m_TorchConstantInt(&padHigh)))
+            return Torch::kUnknownSize;
+          return inputSize + padLow + padHigh;
+        };
+
         // Pad output shape must be computed explicitly from the pad values
+        // for static dims.
         SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
-        for (uint32_t i = 0; i < padding.size() / 2; i++) {
-          newInputShape[2 + i] +=
-              padding[i] + padding[(padding.size() / 2) + i];
+        for (uint32_t i = 0; i < paddingValues.size() / 2; i++) {
+          newInputShape[2 + i] = getPadOutputSizeForInput(
+              i, (paddingValues.size() / 2) + i, newInputShape[2 + i]);
         }
+
         auto padTy = rewriter.getType<Torch::ValueTensorType>(
             newInputShape, inputTensorType.getDtype());
         paddedInput = rewriter.create<Torch::AtenPadOp>(
             loc, padTy, input, padsSizeList, modeVal, constantValue);
       }
     }
+
     for (int64_t i : dilations) {
       cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i)));
@@ -1493,8 +1562,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i)));
     }
-    Value cstZero = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(0));
+
     cstOutputPadding = {cstZero, cstZero};
 
     Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(