diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9c3da6405d59..cc39e06fe8b5 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1679,6 +1679,631 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
   }
 }

+namespace {
+class ConvertAtenConvolutionBackwardOp
+    : public OpConversionPattern<AtenConvolutionBackwardOp> {
+  using IT = utils::IteratorType;
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenConvolutionBackwardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    MLIRContext *context = op->getContext();
+    Value gradOutput = adaptor.getGradOutput();
+    Value input = adaptor.getInput();
+    Value weight = adaptor.getWeight();
+
+    auto gradOutputDTy =
+        cast<RankedTensorType>(gradOutput.getType()).getElementType();
+    auto inputDTy = cast<RankedTensorType>(input.getType()).getElementType();
+    auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
+    if (!isa<mlir::FloatType>(gradOutputDTy) ||
+        !isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy))
+      return op.emitError("unimplemented: only fp convolution bwd supported");
+
+    size_t gradRank = cast<RankedTensorType>(gradOutput.getType()).getRank();
+    size_t numSpatialDims = gradRank - 2;
+    if (numSpatialDims < 1 || numSpatialDims > 3)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only 1d-3d convolution bwd currently supported");
+
+    // Transposed convolution backward is not handled here yet.
+    bool transposed = false;
+    if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant bool for transposed");
+    if (transposed)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: transposed convolution backward");
+
+    // The `outMask` contains 3 boolean values for the results `grad_input`,
+    // `grad_weight`, and `grad_bias` respectively. The value being `false`
+    // means that the corresponding result will be none.
+    SmallVector<bool> outMask;
+    if (!matchPattern(op.getOutputMask(),
+                      m_TorchListOfConstantBools(outMask)) ||
+        outMask.size() != 3)
+      return rewriter.notifyMatchFailure(
+          op, "only constant bool output_mask list of size 3 is supported.");
+    for (unsigned i = 0; i < outMask.size(); i++) {
+      if (outMask[i] == false) {
+        Value result = op->getResults()[i];
+        if (!result.getUsers().empty())
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: false value supported for output_mask only "
+                  "when the result tensor corresponding to that has no users.");
+      }
+    }
+
+    // Check for a valid group size.
+    int64_t numGroups;
+    if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
+      return rewriter.notifyMatchFailure(op,
+                                         "only constant group size supported.");
+    bool isGroupedConvBwd = numGroups > 1;
+    int64_t spatialStartDimIdx = isGroupedConvBwd ? 3 : 2;
+
+    // Stride, padding, dilation for the backward conv. We only support
+    // constant lists here, consistent with forward convolution lowering.
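For context, the contract being lowered here is the one the e2e tests further down exercise: `aten.convolution_backward` takes `grad_output` plus the forward `input` and `weight`, and returns the gradients with respect to input, weight, and bias, with `output_mask` selecting which of the three results are materialized. A quick autograd cross-check of that contract (shapes and names below are illustrative only, not taken from this patch):

```python
import torch

inp = torch.randn(1, 2, 8, 8, requires_grad=True)
wgt = torch.randn(4, 2, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad=True)

out = torch.nn.functional.conv2d(inp, wgt, bias, stride=1, padding=1)
grad_out = torch.randn_like(out)

# The op this pattern lowers: three results gated by output_mask.
grad_in, grad_w, grad_b = torch.ops.aten.convolution_backward(
    grad_out, inp, wgt, bias_sizes=[4],
    stride=[1, 1], padding=[1, 1], dilation=[1, 1],
    transposed=False, output_padding=[0, 0], groups=1,
    output_mask=[True, True, True],
)

out.backward(grad_out)
assert torch.allclose(grad_in, inp.grad, atol=1e-4)
assert torch.allclose(grad_w, wgt.grad, atol=1e-4)
assert torch.allclose(grad_b, bias.grad, atol=1e-4)
```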
+ SmallVector paddingIntValues; + SmallVector strideInts, dilationInts, outputPaddingInts; + + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int strides"); + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int dilations"); + if (!matchPattern(op.getOutputPadding(), + m_TorchListOfConstantInts(outputPaddingInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int output paddings"); + if (!llvm::all_of(outputPaddingInts, + [](int64_t outPad) { return outPad == 0; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only output padding of 0 supported."); + + if (!getListConstructElements(op.getPadding(), paddingIntValues)) + return rewriter.notifyMatchFailure( + op, "only support padding from a list construct"); + paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), + paddingIntValues); + + // The expandGroups lambda function below is used to expand the group + // dimension for weights and input, output tensors. + // For input tensor (dim = 1) : N,C,H,W -> N,G,C/G,H,W + // For grad_output tensor (dim = 1): N,F,H,W -> N,G,F/G,H,W + // For weight tensor (dim = 0) : F,C,H,W -> G,F/G,C,H,W + auto expandGroups = [&](Value tensor, int64_t dim) { + auto inType = cast(tensor.getType()); + auto inShape = makeShapeTorchCompatible(inType.getShape()); + + SmallVector outShape; + for (auto i = 0; i < static_cast(inShape.size()); i++) { + if (i == dim) { + outShape.push_back(numGroups); + outShape.push_back(inShape[i] == kUnknownSize + ? kUnknownSize + : inShape[i] / numGroups); + } else { + outShape.push_back(inShape[i]); + } + } + + SmallVector indices; + for (auto i = 0; i <= static_cast(inShape.size()); i++) { + if (i == dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, + indices); + }; + + SmallVector newResults(op->getNumResults()); + + // Computing Backward-Input Convolution. + if (outMask[0]) { + // If convolution bwd is grouped, `grad_output` should be expanded. + auto gradOutputExpanded = + isGroupedConvBwd ? expandGroups(gradOutput, 1) : gradOutput; + // If convolution bwd is grouped, `weight` should be expanded + auto weightExpanded = isGroupedConvBwd ? expandGroups(weight, 0) : weight; + + // Flip weight along spatial dims only if number of spatial dims > 1. 
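The data-gradient path below relies on the standard identity that, for unit stride, the input gradient is a correlation of `grad_output` with the spatially flipped kernel under the adjusted padding `p'[i] = d[i] * (K[i] - 1) - p[i]`. A minimal 1-D PyTorch check of that identity (the conv1d-based formulation, shapes, and names are illustrative, not part of the patch):

```python
import torch

N, C, F, H, K, d, p = 2, 3, 4, 10, 3, 2, 1
x = torch.randn(N, C, H, requires_grad=True)
w = torch.randn(F, C, K)

y = torch.nn.functional.conv1d(x, w, stride=1, padding=p, dilation=d)
gy = torch.randn_like(y)
y.backward(gy)

# Flip the kernel along the spatial axis, swap its F/C axes so grad_output's
# channels are the ones reduced, and correlate with padding p' = d*(K-1) - p.
pp = d * (K - 1) - p
w_flipped = w.flip(-1).transpose(0, 1)  # (C, F, K)
grad_in = torch.nn.functional.conv1d(gy, w_flipped, stride=1,
                                     padding=pp, dilation=d)

assert torch.allclose(grad_in, x.grad, atol=1e-4)
```

The generic op created below performs this same contraction; the flip is done explicitly by `flipTensor` and the adjusted padding (or, for non-unit strides, a strided scatter) is applied to `grad_output`.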
+ SmallVector weightFlipDims; + weightFlipDims.reserve(numSpatialDims); + for (int64_t i = 0; i < static_cast(numSpatialDims); ++i) + weightFlipDims.push_back(spatialStartDimIdx + i); + weightExpanded = torch_to_linalg::flipTensor( + rewriter, loc, weightExpanded, weightFlipDims); + + // For backward-input, padding must be adjusted to: + // p'[i] = d[i] * (K[i] - 1) - p[i] + Value c1 = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + SmallVector dilationIntValues = + getAsConstantIntValues(rewriter, loc, dilationInts); + SmallVector weiSizes = + getTensorSizes(rewriter, loc, weightExpanded); + SmallVector paddingValues(numSpatialDims); + for (size_t i = 0; i < numSpatialDims; ++i) { + Value kSize = + castIndexToInt64(rewriter, loc, weiSizes[spatialStartDimIdx + i]); + Value kMinusOne = rewriter.createOrFold(loc, kSize, c1); + Value mul = rewriter.createOrFold(loc, kMinusOne, + dilationIntValues[i]); + paddingValues[i] = + arith::SubIOp::create(rewriter, loc, mul, paddingIntValues[i]); + + if (isValueNegative(paddingValues[i])) + return rewriter.notifyMatchFailure( + op, "unimplemented: negative padding values are not supported."); + } + + // If there are not unit strides, we have to scatter `grad_output` into a + // zero-initialized tensor. + SmallVector gradInputSizes = getTensorSizes(rewriter, loc, input); + Value gradOutputSliced; + if (llvm::any_of(strideInts, [](int64_t stride) { return stride > 1; })) { + // Destination spatial sizes are computed as: + // size[i] = (D[i] - 1) + d[i] * (K[i] - 1) + 1 + // Offsets on spatial dims are paddings + // Strides on spatial dims are the original stride[i]. + Value zero = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); + Value one = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + // Initialize slice strides, sizes and offsets + SmallVector goSizes = + getTensorSizes(rewriter, loc, gradOutputExpanded); + SmallVector sizes(goSizes.begin(), + goSizes.begin() + spatialStartDimIdx); + SmallVector offsets(spatialStartDimIdx, zero); + SmallVector strides(spatialStartDimIdx, one); + for (size_t i = 0; i < numSpatialDims; ++i) { + // Shapes of `grad_input` has not been expanded yet + // if it's needed for group conv even + Value h = gradInputSizes[2 + i]; + Value k = weiSizes[spatialStartDimIdx + i]; + Value hMinusOne = rewriter.createOrFold(loc, h, one); + Value kMinusOne = rewriter.createOrFold(loc, k, one); + Value mul = rewriter.createOrFold( + loc, castIntToIndex(rewriter, loc, dilationIntValues[i]), + kMinusOne); + Value sum = rewriter.createOrFold(loc, hMinusOne, mul); + sizes.push_back(rewriter.createOrFold(loc, sum, one)); + offsets.push_back(castIntToIndex(rewriter, loc, paddingValues[i])); + + Value strideIntValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(strideInts[i])); + strides.push_back(castIntToIndex(rewriter, loc, strideIntValue)); + } + + Value zeroInit = + createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy); + gradOutputSliced = tensor::InsertSliceOp::create( + rewriter, loc, + torch_to_linalg::removeSizeInformation(rewriter, loc, + gradOutputExpanded), + zeroInit, offsets, goSizes, strides); + } else { + // If there unit strides, pad `grad_output` spatial dims with zeros. + // If conv is grouped, output has shape: + // N x G x F/G x . Otherwise: N x F x . 
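Either way, `grad_output` is re-laid-out before the contraction: with unit strides it is simply zero-padded by `p'` (the branch that follows), while with non-unit strides the `insert_slice` above scatters it into a zero-initialized buffer, placing element `o` of each spatial axis at offset `p' + s*o` in a buffer of extent `(H - 1) + d*(K - 1) + 1`, where `H` is the corresponding grad_input size. A 1-D sketch of that scatter (values are illustrative):

```python
import numpy as np

H, K, s, d, p = 10, 3, 2, 1, 2               # grad_input size, kernel, stride, dilation, padding
O = (H + 2 * p - d * (K - 1) - 1) // s + 1   # forward output size
pp = d * (K - 1) - p                         # adjusted padding, must be >= 0

gy = np.random.randn(O)
buf = np.zeros((H - 1) + d * (K - 1) + 1)    # destination of the insert_slice
buf[pp : pp + s * O : s] = gy                # offset p', step s along the axis

# A stride-1 correlation of `buf` with the flipped, dilated kernel now has
# spatial extent H and equals grad_input, reducing the strided case to the
# unit-stride identity sketched earlier.
```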
+ Value padVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(gradOutputDTy, 0.0)); + gradOutputSliced = torch_to_linalg::getDynamicZeroPaddedTensor( + op, rewriter, gradOutputExpanded, paddingValues, spatialStartDimIdx, + padVal); + } + + // Initialize output buffer. For grouped, compute into an expanded + // [N, G, C/G, D*] tensor and collapse back to the original input shape. + Value gradInputInit = + createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy); + SmallVector gradInputCollapseIndices; + if (isGroupedConvBwd) { + auto gradInputInitExpanded = expandGroups(gradInputInit, 1); + gradInputInit = gradInputInitExpanded.getResult(); + gradInputCollapseIndices = + gradInputInitExpanded.getReassociationIndices(); + } + + // Generate GenericOp. + SmallVector indexingMaps; + SmallVector iteratorTypes; + initIndexingMapsAndIteratorTypesForDataBwd( + rewriter, context, isGroupedConvBwd, numSpatialDims, dilationInts, + indexingMaps, iteratorTypes); + auto genericRes = + createGenericOp(rewriter, loc, gradOutputSliced, weightExpanded, + gradInputInit, indexingMaps, iteratorTypes) + .getResult(0); + + // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the generic op + // if it is grouped. + if (isGroupedConvBwd) { + genericRes = tensor::CollapseShapeOp::create( + rewriter, loc, input.getType(), genericRes, + gradInputCollapseIndices); + } + + // Cast to the final result type expected by the type converter. + newResults[0] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(0).getType()), + genericRes) + .getResult(); + } + + // Computing Backward-Weight Convolution. + if (outMask[1]) { + // If convolution bwd is grouped, `grad_output` should be expanded. + auto gradOutputExpanded = + isGroupedConvBwd ? expandGroups(gradOutput, 1) : gradOutput; + // If convolution bwd is grouped, `input` should be expanded + auto inputExpanded = isGroupedConvBwd ? expandGroups(input, 1) : input; + + // Pad input spatial dims with zeros. If grouped, input has shape: + // N x G x C/G x . Otherwise: N x C x . + // We should only pad the spatial dims, so set unpaddedDims accordingly. + Value padVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(inputDTy, 0.0)); + Value paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( + op, rewriter, inputExpanded, paddingIntValues, spatialStartDimIdx, + padVal); + + // Initialize output buffer. For grouped, compute into an expanded + // [G, F/G, C/G, K*] tensor and collapse back to the original weight + // shape. + SmallVector gradWeightSizes = + getTensorSizes(rewriter, loc, weight); + Value gradWeightInit = + createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy); + SmallVector gradWeightCollapseIndices; + if (isGroupedConvBwd) { + auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0); + gradWeightInit = gradWeightInitExpanded.getResult(); + gradWeightCollapseIndices = + gradWeightInitExpanded.getReassociationIndices(); + } + + // Generate GenericOp. + SmallVector indexingMaps; + SmallVector iteratorTypes; + initIndexingMapsAndIteratorTypesForWeightBwd( + rewriter, context, isGroupedConvBwd, numSpatialDims, strideInts, + dilationInts, indexingMaps, iteratorTypes); + auto genericRes = + createGenericOp(rewriter, loc, paddedInput, gradOutputExpanded, + gradWeightInit, indexingMaps, iteratorTypes) + .getResult(0); + + // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the generic op + // if it is grouped. 
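The weight gradient produced by this second generic op is itself a correlation, with the roles of stride and dilation swapped relative to the forward op: `grad_output` slides over the padded input with step equal to the forward stride, while the kernel axis advances by the forward dilation, i.e. `dLdw[f, c, k] = sum_{n, o} x_pad[n, c, d*k + s*o] * dLdy[n, f, o]`. A brute-force 1-D cross-check against autograd (illustrative shapes, not part of the patch):

```python
import torch

N, C, F, H, K, s, d, p = 2, 3, 4, 9, 3, 2, 2, 1
x = torch.randn(N, C, H)
w = torch.randn(F, C, K, requires_grad=True)

y = torch.nn.functional.conv1d(x, w, stride=s, padding=p, dilation=d)
gy = torch.randn_like(y)
y.backward(gy)

xp = torch.nn.functional.pad(x, (p, p))  # zero-pad the spatial dim only
O = y.shape[-1]

# Same contraction the generic op performs: reduce over batch and output
# positions, indexing the padded input at d*k + s*o.
got = torch.zeros(F, C, K)
for f in range(F):
    for c in range(C):
        for k in range(K):
            for n in range(N):
                for o in range(O):
                    got[f, c, k] += xp[n, c, d * k + s * o] * gy[n, f, o]

assert torch.allclose(got, w.grad, atol=1e-4)
```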
+ if (isGroupedConvBwd) { + genericRes = tensor::CollapseShapeOp::create( + rewriter, loc, weight.getType(), genericRes, + gradWeightCollapseIndices); + } + + // Cast to the final result type expected by the type converter. + newResults[1] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(1).getType()), + genericRes) + .getResult(); + } + + // Computing Backward-Bias Convolution. + if (outMask[2]) { + // Sum grad_output along all dims except F using linalg. + DenseSet reduceDims; + reduceDims.insert(0); + for (int64_t i = 2; i < static_cast(gradRank); ++i) + reduceDims.insert(i); + + torch_to_linalg::ReductionOpInfo opInfo{false, gradOutput, reduceDims}; + + // Zero init for the element type (arith.constant expects a scalar attr). + Value initSum = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(gradOutputDTy)); + + auto reductionBody = [&](OpBuilder &b, Location loc, ValueRange args) { + Value x = args[0]; + Value acc = args[1]; + Value sum = arith::AddFOp::create(b, loc, x, acc); + linalg::YieldOp::create(b, loc, sum); + }; + + Value gradBias = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, opInfo, initSum, reductionBody); + + newResults[2] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(2).getType()), + gradBias) + .getResult(); + } + + rewriter.replaceOp(op, newResults); + + return success(); + } + +private: + static void initIndexingMapsAndIteratorTypesForDataBwd( + OpBuilder &rewriter, MLIRContext *context, bool isGrouped, + int numSpatialDims, const SmallVector &dilationInts, + SmallVector &indexingMaps, SmallVector &iteratorTypes) { + // To calculate convolution backward-data, we use generic operation. + // The generic operation is a generalization of the convolution operation + // that can handle any number of spatial dimensions. + // The generic operation is defined as follows: + // ``` + // dLdx[n, g, c, o] = sum(dLdy[n, g, f, d0 * k + o] * w[g, f, c, k] + // for n in range(batch_size) for o in range(in_spatial_dims)) + // ``` + // where `n` is the batch dimension, `g` is the group dimension, + // `c` is the input channel dimension, `f` is the output channel + // dimension, `o` is the input spatial dimension, `k` is the kernel + // dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the + // gradient of the output tensor. `dLdx` is the data-gradient tensor. 
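The grouped variants of these maps thread one extra `g` dimension through the same contraction; the `expandGroups` reshapes above put `grad_output`, `weight`, and the result into that `[N, G, C/G, ...]` / `[G, F/G, C/G, ...]` layout. A per-group check of the data-gradient identity in plain PyTorch (illustrative shapes and helper code, not part of the patch):

```python
import torch

G, N, C, F, H, K, p = 4, 2, 8, 12, 10, 3, 1
x = torch.randn(N, C, H, requires_grad=True)
w = torch.randn(F, C // G, K)

y = torch.nn.functional.conv1d(x, w, padding=p, groups=G)
gy = torch.randn_like(y)
y.backward(gy)

# Expand N,F -> N,G,F/G and F,C/G -> G,F/G,C/G, then apply the per-group
# flipped-kernel correlation (n, g, c, o) += (n, g, f, o + k) * (g, f, c, k).
pp = (K - 1) - p                      # unit dilation here
gy_g = gy.view(N, G, F // G, -1)
w_g = w.view(G, F // G, C // G, K).flip(-1)

grad_in = torch.stack(
    [torch.nn.functional.conv1d(gy_g[:, g], w_g[g].transpose(0, 1), padding=pp)
     for g in range(G)],
    dim=1,
).reshape(N, C, H)

assert torch.allclose(grad_in, x.grad, atol=1e-4)
```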
+ if (!isGrouped) { + if (numSpatialDims == 1) { + AffineExpr n, c, o, f, k; + bindDims(context, n, c, o, f, k); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector goExprs = {n, f, d0 * k + o}; + SmallVector weiExprs = {f, c, k}; + SmallVector outExprs = {n, c, o}; + indexingMaps = {AffineMap::get(5, 0, goExprs, context), + AffineMap::get(5, 0, weiExprs, context), + AffineMap::get(5, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr n, c, oh, ow, f, kh, kw; + bindDims(context, n, c, oh, ow, f, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector goExprs = {n, f, d0 * kh + oh, d1 * kw + ow}; + SmallVector weiExprs = {f, c, kh, kw}; + SmallVector outExprs = {n, c, oh, ow}; + indexingMaps = {AffineMap::get(7, 0, goExprs, context), + AffineMap::get(7, 0, weiExprs, context), + AffineMap::get(7, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction, + IT::reduction}; + } else { + AffineExpr n, c, od, oh, ow, f, kd, kh, kw; + bindDims(context, n, c, od, oh, ow, f, kd, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector goExprs = {n, f, d0 * kd + od, d1 * kh + oh, + d2 * kw + ow}; + SmallVector weiExprs = {f, c, kd, kh, kw}; + SmallVector outExprs = {n, c, od, oh, ow}; + indexingMaps = {AffineMap::get(9, 0, goExprs, context), + AffineMap::get(9, 0, weiExprs, context), + AffineMap::get(9, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction, IT::reduction}; + } + } else { + if (numSpatialDims == 1) { + AffineExpr n, g, cg, o, fg, k; + bindDims(context, n, g, cg, o, fg, k); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector goExprs = {n, g, fg, d0 * k + o}; + SmallVector weiExprs = {g, fg, cg, k}; + SmallVector outExprs = {n, g, cg, o}; + indexingMaps = {AffineMap::get(6, 0, goExprs, context), + AffineMap::get(6, 0, weiExprs, context), + AffineMap::get(6, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr n, g, cg, oh, ow, fg, kh, kw; + bindDims(context, n, g, cg, oh, ow, fg, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector goExprs = {n, g, fg, d0 * kh + oh, + d1 * kw + ow}; + SmallVector weiExprs = {g, fg, cg, kh, kw}; + SmallVector outExprs = {n, g, cg, oh, ow}; + indexingMaps = {AffineMap::get(8, 0, goExprs, context), + AffineMap::get(8, 0, weiExprs, context), + AffineMap::get(8, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction}; + } else { + AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw; + bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = 
rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector goExprs = { + n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow}; + SmallVector weiExprs = {g, fg, cg, kd, kh, kw}; + SmallVector outExprs = {n, g, cg, od, oh, ow}; + indexingMaps = {AffineMap::get(10, 0, goExprs, context), + AffineMap::get(10, 0, weiExprs, context), + AffineMap::get(10, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction, IT::reduction, + IT::reduction}; + } + } + } + + static void initIndexingMapsAndIteratorTypesForWeightBwd( + OpBuilder &rewriter, MLIRContext *context, bool isGrouped, + int numSpatialDims, const SmallVector &strideInts, + const SmallVector &dilationInts, + SmallVector &indexingMaps, SmallVector &iteratorTypes) { + // To calculate convolution backward-weight, we use generic operation. + // The generic operation is a generalization of the convolution operation + // that can handle any number of spatial dimensions. + // The generic operation is defined as follows: + // ``` + // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o] + // for n in range(batch_size) for o in range(output_spatial_dims)) + // ``` + // where `n` is the batch dimension, `g` is the group dimension, + // `c` is the input channel dimension, `f` is the output channel + // dimension, `o` is the output spatial dimension, `k` is the kernel + // dimension, `d0` is dilation and `s0` is stride. `x` is the input + // tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the + // weight-gradient tensor. + if (!isGrouped) { + if (numSpatialDims == 1) { + AffineExpr f, c, k, n, o; + bindDims(context, f, c, k, n, o); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector inExprs = {n, c, d0 * k + s0 * o}; + SmallVector goExprs = {n, f, o}; + SmallVector outExprs = {f, c, k}; + indexingMaps = {AffineMap::get(5, 0, inExprs, context), + AffineMap::get(5, 0, goExprs, context), + AffineMap::get(5, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr f, c, kh, kw, n, oh, ow; + bindDims(context, f, c, kh, kw, n, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector inExprs = {n, c, d0 * kh + s0 * oh, + d1 * kw + s1 * ow}; + SmallVector goExprs = {n, f, oh, ow}; + SmallVector outExprs = {f, c, kh, kw}; + indexingMaps = {AffineMap::get(7, 0, inExprs, context), + AffineMap::get(7, 0, goExprs, context), + AffineMap::get(7, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction, + IT::reduction}; + } else { + AffineExpr f, c, kd, kh, kw, n, od, oh, ow; + bindDims(context, f, c, kd, kh, kw, n, od, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector 
inExprs = { + n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow}; + SmallVector goExprs = {n, f, od, oh, ow}; + SmallVector outExprs = {f, c, kd, kh, kw}; + indexingMaps = {AffineMap::get(9, 0, inExprs, context), + AffineMap::get(9, 0, goExprs, context), + AffineMap::get(9, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction, IT::reduction}; + } + } else { + if (numSpatialDims == 1) { + AffineExpr g, fg, cg, k, n, o; + bindDims(context, g, fg, cg, k, n, o); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector inExprs = {n, g, cg, d0 * k + s0 * o}; + SmallVector goExprs = {n, g, fg, o}; + SmallVector outExprs = {g, fg, cg, k}; + indexingMaps = {AffineMap::get(6, 0, inExprs, context), + AffineMap::get(6, 0, goExprs, context), + AffineMap::get(6, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr g, fg, cg, kh, kw, n, oh, ow; + bindDims(context, g, fg, cg, kh, kw, n, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector inExprs = {n, g, cg, d0 * kh + s0 * oh, + d1 * kw + s1 * ow}; + SmallVector goExprs = {n, g, fg, oh, ow}; + SmallVector outExprs = {g, fg, cg, kh, kw}; + indexingMaps = {AffineMap::get(8, 0, inExprs, context), + AffineMap::get(8, 0, goExprs, context), + AffineMap::get(8, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction}; + } else { + AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow; + bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector inExprs = { + n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow}; + SmallVector goExprs = {n, g, fg, od, oh, ow}; + SmallVector outExprs = {g, fg, cg, kd, kh, kw}; + indexingMaps = {AffineMap::get(10, 0, inExprs, context), + AffineMap::get(10, 0, goExprs, context), + AffineMap::get(10, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction, IT::reduction, + IT::reduction}; + } + } + } + + static linalg::GenericOp + createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out, + const SmallVector &indexingMaps, + const SmallVector &iteratorTypes) { + return linalg::GenericOp::create( + b, loc, out.getType(), ValueRange{in0, in1}, out, indexingMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + Value grad = args[1]; + Value output = args[2]; + + // Convert input and grad to accumulator type if needed + Type accType = output.getType(); + if (input.getType() != accType) { + input = 
arith::ExtFOp::create(b, loc, accType, input); + } + if (grad.getType() != accType) { + grad = arith::ExtFOp::create(b, loc, accType, grad); + } + + Value mul = arith::MulFOp::create(b, loc, input, grad); + Value sum = arith::AddFOp::create(b, loc, mul, output); + linalg::YieldOp::create(b, loc, sum); + }); + } +}; +} // namespace + namespace { /// Creates coefficients based on DFT definition, see @@ -1880,6 +2505,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c9c42b43c463..37e5939b05d4 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5831,406 +5831,6 @@ class DecomposeAtenConvTranspose3dOp }; } // namespace -// The convolution backward op is decomposed as follows: -// inputH, inputW = input.shape[2:] -// output_padding_ = [ -// inputH -// - 1 -// + 2 * padding_[0] -// - dilation_[0] * (weight.shape[2] - 1) -// - (grad_output.shape[2] - 1) * stride_[0], -// inputW -// - 1 -// + 2 * padding_[1] -// - dilation_[1] * (weight.shape[3] - 1) -// - (grad_output.shape[3] - 1) * stride_[1], -// ] -// -// decomp_grad_input = torch.nn.functional.conv_transpose2d( -// grad_output, -// weight, -// None, -// stride_, -// padding_, -// output_padding_, -// groups_, -// dilation_, -// ) -// -// input_transposed = torch.ops.aten.transpose(input, 0, 1) -// grad_output_transposed = grad_output.view( -// grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] -// ) -// decomp_grad_weight = torch.ops.aten.convolution( -// input_transposed, -// grad_output_transposed, -// bias=None, -// stride=dilation_, -// padding=padding_, -// dilation=stride_, -// transposed=False, -// output_padding=[0, 0], -// groups=input.shape[0], -// ) -// decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2]) -// decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3]) -// decomp_grad_weight = decomp_grad_weight.view( -// input_transposed.shape[0], -// input_transposed.shape[1], -// grad_output.shape[1], -// *decomp_grad_weight.shape[2:] -// ) -// decomp_grad_weight = decomp_grad_weight.movedim(0, 2) -// decomp_grad_weight = decomp_grad_weight.sum(dim=0) -// -// decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3]) -namespace { -class DecomposeAtenConvolutionBackwardOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op, - PatternRewriter &rewriter) const override { - - Location loc = op.getLoc(); - MLIRContext *context = op.getContext(); - Value input = op.getInput(); - Value weight = op.getWeight(); - Value gradOutput = op.getGradOutput(); - std::optional maybeGradRank = getTensorRank(gradOutput); - if (!maybeGradRank) { - return rewriter.notifyMatchFailure(op, - "expected grad output to have a rank"); - } - unsigned gradRank = *maybeGradRank; - if (gradRank != 4) - return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D convolutions supported."); - - Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(0)); - Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, - 
rewriter.getI64IntegerAttr(1)); - Value cstTwo = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(2)); - Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); - Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, - rewriter.getBoolAttr(false)); - - SmallVector padding, dilation, stride; - SmallVector paddingInt, dilationInt, strideInt, - outputPaddingInt; - - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt))) - return rewriter.notifyMatchFailure( - op, "padding must be a list of constant ints"); - - if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt))) - return rewriter.notifyMatchFailure( - op, "stride must be a list of constant ints"); - - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt))) - return rewriter.notifyMatchFailure( - op, "dilation must be a list of constant ints"); - if (!llvm::all_of(dilationInt, - [](int64_t dilationVal) { return dilationVal == 1; })) - return rewriter.notifyMatchFailure( - op, "unimplemented: only dilations of 1 supported."); - - if (!matchPattern(op.getOutputPadding(), - m_TorchListOfConstantInts(outputPaddingInt))) - return rewriter.notifyMatchFailure( - op, "output padding must be a list of constant ints"); - if (!llvm::all_of(outputPaddingInt, - [](int64_t outPad) { return outPad == 0; })) - return rewriter.notifyMatchFailure( - op, "unimplemented: only output padding of 0 supported."); - - // The `outMask` contains 3 boolean values for the results `grad_input`, - // `grad_weight`, and `grad_bias` respectively. The value being `false` - // means that the corresponding result will be none. - SmallVector outMask; - if (!matchPattern(op.getOutputMask(), - m_TorchListOfConstantBools(outMask)) || - outMask.size() != 3) - return rewriter.notifyMatchFailure( - op, "only constant bool output_mask list of size 3 is supported."); - for (unsigned i = 0; i < outMask.size(); i++) { - if (outMask[i] == false) { - Value result = op->getResults()[i]; - if (!result.getUsers().empty()) - return rewriter.notifyMatchFailure( - op, "unimplemented: false value supported for output_mask only " - "when the result tensor corresponding to that has no users."); - } - } - - bool transposed; - if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) - return rewriter.notifyMatchFailure( - op, "transposed arg should be a constant bool."); - if (transposed) - return rewriter.notifyMatchFailure( - op, "unimplemented: transposed convolutions are not supported."); - - Value gradInput = cstNone; - if (outMask[0]) { - // Computing Grad Input. - getListConstructElements(op.getPadding(), padding); - getListConstructElements(op.getStride(), stride); - getListConstructElements(op.getDilation(), dilation); - - // Calculate output padding for first convolution. 
- // output_padding_ = [ - // inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2] - // - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 * - // padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) - - // ((grad_out.size()[3] - 1) * stride_[1]), - // ] - SmallVector outputPaddingValues; - for (unsigned i = 2; i < gradRank; i++) { - Value dim = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(i)); - Value inputVecDim = - Torch::AtenSizeIntOp::create(rewriter, loc, input, dim); - Value gradOutDim = - Torch::AtenSizeIntOp::create(rewriter, loc, gradOutput, dim); - Value weightDim = - Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); - Value inputVecDimMinusOne = - Torch::AtenSubIntOp::create(rewriter, loc, inputVecDim, cstOne); - Value gradOutDimMinusOne = - Torch::AtenSubIntOp::create(rewriter, loc, gradOutDim, cstOne); - Value weightDimMinusOne = - Torch::AtenSubIntOp::create(rewriter, loc, weightDim, cstOne); - Value twoTimesPadding = - Torch::AtenMulIntOp::create(rewriter, loc, padding[i - 2], cstTwo); - Value tmpA = Torch::AtenMulIntOp::create( - rewriter, loc, weightDimMinusOne, dilation[i - 2]); - Value tmpB = Torch::AtenMulIntOp::create( - rewriter, loc, gradOutDimMinusOne, stride[i - 2]); - Value outputPaddingVal = AtenAddIntOp::create( - rewriter, loc, inputVecDimMinusOne, twoTimesPadding); - outputPaddingVal = - AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpA); - outputPaddingVal = - AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpB); - outputPaddingValues.push_back(outputPaddingVal); - } - Value outputPaddingForGradInput = Torch::PrimListConstructOp::create( - rewriter, loc, ListType::get(IntType::get(context)), - outputPaddingValues); - gradInput = Torch::AtenConvTranspose2dInputOp::create( - rewriter, loc, op.getResultTypes()[0], gradOutput, weight, cstNone, - op.getStride(), op.getPadding(), outputPaddingForGradInput, - op.getGroups(), op.getDilation()); - } - - Value gradWeight = cstNone; - if (outMask[1]) { - // Computing Grad Weight. - Type transposedType; - if (failed(getTransposedType(cast(input.getType()), 0, 1, - transposedType))) - return failure(); - Value inputTransposed = Torch::AtenTransposeIntOp::create( - rewriter, loc, transposedType, input, cstZero, cstOne); - - // For the cases where the stride is non-unit, we compute the `GradWeight` - // through this implementation. 
- if (!llvm::all_of(strideInt, - [](int64_t stride) { return stride == 1; })) { - SmallVector gradOutputSize; - for (unsigned i = 0; i < gradRank; i++) { - gradOutputSize.push_back(Torch::AtenSizeIntOp::create( - rewriter, loc, gradOutput, - Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(i)))); - } - - Value gradOutputViewDimZero = Torch::AtenMulIntOp::create( - rewriter, loc, gradOutputSize[0], gradOutputSize[1]); - Value gradOutputViewShapeList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], - gradOutputSize[3]}); - - BaseTensorType gradOutputTy = - cast(gradOutput.getType()); - if (!gradOutputTy.hasSizes()) - return failure(); - SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); - SmallVector gradOutputViewSizesInt(gradOutputSizesInt); - if (gradOutputViewSizesInt[0] != kUnknownSize && - gradOutputViewSizesInt[1] != kUnknownSize) - gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1]; - else - gradOutputViewSizesInt[0] = kUnknownSize; - gradOutputViewSizesInt[1] = 1; - BaseTensorType gradOutputTypeForView = - cast(gradOutputTy.getWithSizesAndDtype( - llvm::ArrayRef(gradOutputViewSizesInt), - gradOutputTy.getOptionalDtype())); - Value gradOutputView = - Torch::AtenViewOp::create(rewriter, loc, gradOutputTypeForView, - gradOutput, gradOutputViewShapeList); - - BaseTensorType inputTransposedTy = - cast(inputTransposed.getType()); - if (!inputTransposedTy.hasSizes()) - return failure(); - SmallVector inputTransposedSizesInt( - inputTransposedTy.getSizes()); - SmallVector gradWeightSizesInt{inputTransposedSizesInt[0], - gradOutputViewSizesInt[0]}; - for (unsigned i = 2; i < gradRank; i++) { - if (inputTransposedSizesInt[i] != kUnknownSize && - gradOutputViewSizesInt[i] != kUnknownSize) { - int64_t kernelSizeInt = - strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1; - gradWeightSizesInt.push_back( - ((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) - - kernelSizeInt) / - dilationInt[i - 2]) + - 1); - } else { - gradWeightSizesInt.push_back(kUnknownSize); - } - } - - BaseTensorType gradWeightTy = - cast(inputTransposedTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightSizesInt), - inputTransposedTy.getOptionalDtype())); - - Value numGroup = AtenSizeIntOp::create(rewriter, loc, input, cstZero); - gradWeight = Torch::AtenConvolutionOp::create( - rewriter, loc, gradWeightTy, inputTransposed, gradOutputView, - cstNone, - /*stride=*/op.getDilation(), op.getPadding(), - /*dilation=*/op.getStride(), op.getTransposed(), - op.getOutputPadding(), numGroup); - - BaseTensorType weightTy = cast(weight.getType()); - if (!weightTy.hasSizes()) - return failure(); - SmallVector weightSizes(weightTy.getSizes()); - for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { - gradWeightSizesInt[i + 2] = weightSizes[i + 2]; - BaseTensorType gradWeightNarrowTy = - cast(gradWeightTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightSizesInt), - gradWeightTy.getOptionalDtype())); - - Value dim = ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(i + 2)); - Value length = - Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); - gradWeight = Torch::AtenNarrowOp::create( - rewriter, loc, gradWeightNarrowTy, gradWeight, dim, - /*start=*/cstZero, length); - } - - SmallVector gradWeightViewShapeInt{ - inputTransposedSizesInt[0], inputTransposedSizesInt[1]}; - gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]); - 
gradWeightViewShapeInt.insert( - gradWeightViewShapeInt.end(), - {gradWeightSizesInt[2], gradWeightSizesInt[3]}); - - SmallVector gradWeightViewShapeValue; - for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { - gradWeightViewShapeValue.push_back(Torch::ConstantIntOp::create( - rewriter, loc, - rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); - } - - Value gradWeightViewShapeList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - gradWeightViewShapeValue); - - BaseTensorType gradWeightTypeForView = - cast(gradWeightTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightViewShapeInt), - gradWeightTy.getOptionalDtype())); - gradWeight = - Torch::AtenViewOp::create(rewriter, loc, gradWeightTypeForView, - gradWeight, gradWeightViewShapeList); - - gradWeightTy = cast(gradWeight.getType()); - SmallVector gradWeightDimsOrder = - computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); - SmallVector gradWeightMoveDimShape; - for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) { - gradWeightMoveDimShape.push_back( - gradWeightViewShapeInt[gradWeightDimsOrder[i]]); - } - BaseTensorType gradWeightTypeForMoveDim = - cast(gradWeightTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightMoveDimShape), - gradWeightTy.getOptionalDtype())); - - gradWeight = - AtenMovedimIntOp::create(rewriter, loc, gradWeightTypeForMoveDim, - gradWeight, /*source=*/cstZero, - /*destination=*/cstTwo); - - Value gradIntList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - llvm::ArrayRef{cstZero}); - gradWeight = Torch::AtenSumDimIntListOp::create( - rewriter, loc, op.getResultTypes()[1], /*self=*/gradWeight, - /*dim=*/gradIntList, - /*keepdim=*/cstFalse, - /*dtype=*/cstNone); - } else { - if (failed(getTransposedType(cast(gradOutput.getType()), - 0, 1, transposedType))) - return failure(); - Value gradOutputTransposed = Torch::AtenTransposeIntOp::create( - rewriter, loc, transposedType, gradOutput, cstZero, cstOne); - // Convolve input with grad_output. - if (failed( - getTransposedType(cast(op.getResultTypes()[1]), - 0, 1, transposedType))) - return failure(); - gradWeight = Torch::AtenConvolutionOp::create( - rewriter, loc, transposedType, inputTransposed, - gradOutputTransposed, cstNone, op.getStride(), op.getPadding(), - op.getDilation(), op.getTransposed(), op.getOutputPadding(), - op.getGroups()); - gradWeight = Torch::AtenTransposeIntOp::create( - rewriter, loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); - } - } - - Value gradBias = cstNone; - if (outMask[2]) { - // Computing Grad Bias. - SmallVector dimIntList{cstZero}; - for (unsigned i = 2; i < gradRank; i++) - dimIntList.push_back(Torch::ConstantIntOp::create( - rewriter, loc, rewriter.getI64IntegerAttr(i))); - Value gradIntList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - dimIntList); - - // Sum grad_output along dim 1. 
- gradBias = Torch::AtenSumDimIntListOp::create( - rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList, - cstFalse, cstNone); - } - - rewriter.replaceOp(op, {gradInput, gradWeight, gradBias}); - return success(); - } -}; -} // namespace - /** * # one dim input * t = torch.tensor([0, 0, 1, 1, 0, 0] @@ -13159,7 +12759,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeAten_ConvolutionLikeOp>( patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index cfc8bb96118b..a1bd4cc9328a 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -444,7 +444,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py index 4f547d531294..095b4db6c5a4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py @@ -34,6 +34,7 @@ "aten.flatten.using_ints", "aten.adaptive_avg_pool1d", "aten.adaptive_avg_pool2d", + "aten.convolution_backward", "aten.unflatten.int", ], OutputType.STABLEHLO: [ diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index 5e6e093902c4..e6bbbe3273dc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -228,6 +228,75 @@ def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3)) +class ConvolutionBackwardModule2DDilated(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 2, 6, 6], torch.float32, True), + ([1, 4, 8, 8], torch.float32, True), + ([2, 4, 3, 3], torch.float32, True), + ] + ) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[1, 1], + padding=[1, 1], + dilation=[2, 2], + transposed=False, + output_padding=[0, 0], + groups=1, + output_mask=[True, True, True], + ) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DDilated()) +def ConvolutionBackwardModule2DDilated_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(1, 2, 6, 6), tu.rand(1, 4, 8, 8), tu.rand(2, 4, 3, 3)) + + +class ConvolutionBackwardModule2DStridedPaddedDilatedGrouped(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 16, 32, 32], torch.float32, True), + ([2, 128, 64, 64], torch.float32, True), + ([16, 32, 2, 2], torch.float32, True), + ] + ) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[2, 2], + padding=[2, 2], + dilation=[4, 4], + transposed=False, + 
output_padding=[0, 0], + groups=4, + output_mask=[True, True, True], + ) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStridedPaddedDilatedGrouped()) +def ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(2, 16, 32, 32), tu.rand(2, 128, 64, 64), tu.rand(16, 32, 2, 2)) + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/convolution_bwd.mlir b/test/Conversion/TorchToLinalg/convolution_bwd.mlir new file mode 100644 index 000000000000..0e1f5e67dbb8 --- /dev/null +++ b/test/Conversion/TorchToLinalg/convolution_bwd.mlir @@ -0,0 +1,318 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,63,63],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,63,63],f32> -> tensor<2x16x63x63xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] low[0, 0, 1, 1] high[0, 0, 1, 1] + // CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK: tensor.yield %[[CST0]] : f32 + // CHECK: } : tensor<2x16x63x63xf32> to tensor<2x16x65x65xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = 
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[W_REV]] : tensor<2x16x65x65xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x63x63xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,63,63],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // 
CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x16x66x66xf32> + // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x16x66x66xf32>) -> tensor<2x16x66x66xf32> + // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0]] into %[[SLICE_FILLED]][0, 0, 0, 0] [2, 16, 33, 33] [1, 1, 2, 2] : tensor<2x16x33x33xf32> into tensor<2x16x66x66xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d5 * 2 + d2, d6 * 2 + d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[SLICE]], %[[W_REV]] : tensor<2x16x66x66xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf 
%[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,32,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,32,2,2],f32> -> tensor<16x32x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xf32> into tensor<2x4x4x33x33xf32> + // CHECK: %[[W_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xf32> into tensor<4x4x32x2x2xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {{.*}} ins(%[[W_EXP]] : tensor<4x4x32x2x2xf32>) outs(%[[W_FILLED]] : tensor<4x4x32x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST1]], %[[I4]] : index + // CHECK-NEXT: 
+ // CHECK-NEXT: linalg.yield %[[EX]] : f32
+ // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32>
+ // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xf32>
+ // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xf32>) -> tensor<2x4x4x66x66xf32>
+ // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xf32> into tensor<2x4x4x66x66xf32>
+ // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32>
+ // CHECK: %[[OUT_EXP:.*]] = tensor.expand_shape %[[OUT_EMPTY]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xf32> into tensor<2x4x32x64x64xf32>
+ // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EXP]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32>
+ // CHECK: %[[CONV:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xf32>, tensor<4x4x32x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+ // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+ // CHECK-NEXT: linalg.yield %[[ACC]] : f32
+ // CHECK-NEXT: } -> tensor<2x4x32x64x64xf32>
+ // CHECK: %[[CONV_COLLAPSED:.*]] = tensor.collapse_shape %[[CONV]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} : tensor<2x4x32x64x64xf32> into tensor<2x128x64x64xf32>
+ // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV_COLLAPSED]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32>
+ // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+ // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+ // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+ // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+ // CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+ // CHECK-NEXT: } -> tensor<16xf32>
+ // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32>
+ // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>
+ %true = torch.constant.bool true
+ %int0 = torch.constant.int 0
+ %false = torch.constant.bool false
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %int4 = torch.constant.int 4
+ %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+ %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+ %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,32,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32>
+ return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>) {
+func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,63,63],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>) {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32>
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,63,63],f32> -> tensor<2x16x63x63xf32>
+ // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32>
+ // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32>
+ // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[T1]], %[[T0]] : tensor<2x128x64x64xf32>, tensor<2x16x63x63xf32>) outs(%[[OUT0_FILLED]] : tensor<16x128x2x2xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+ // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+ // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
+ // CHECK-NEXT: } -> tensor<16x128x2x2xf32>
+ // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<16x128x2x2xf32> -> !torch.vtensor<[16,128,2,2],f32>
+ // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+ // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+ // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x63x63xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+ // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+ // CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+ // CHECK-NEXT: } -> tensor<16xf32>
+ // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32>
+ // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>
+ %true = torch.constant.bool true
+ %int0 = torch.constant.int 0
+ %false = torch.constant.bool false
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+ %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+ %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,63,63],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>
+ return %result1, %result2 : !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_1g(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,32,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[32,128,2,2],f32>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>) {
+func.func @convolution_backward_weights_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,32,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[32,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>) {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32>
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,32,33,33],f32> -> tensor<2x32x33x33xf32>
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[T1]] low[0, 0, 2, 2] high[0, 0, 2, 2]
+ // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+ // CHECK-NEXT: tensor.yield %[[CST]] : f32
+ // CHECK-NEXT: } : tensor<2x128x64x64xf32> to tensor<2x128x68x68xf32>
+ // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<32x128x2x2xf32>
+ // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<32x128x2x2xf32>) -> tensor<32x128x2x2xf32>
+ // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 * 2 + d5 * 2, d3 * 2 + d6 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0]] : tensor<2x128x68x68xf32>, tensor<2x32x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<32x128x2x2xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+ // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+ // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
+ // CHECK-NEXT: } -> tensor<32x128x2x2xf32>
+ // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<32x128x2x2xf32> -> !torch.vtensor<[32,128,2,2],f32>
+ // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<32xf32>
+ // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<32xf32>) -> tensor<32xf32>
+ // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x32x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<32xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+ // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+ // CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+ // CHECK-NEXT: } -> tensor<32xf32>
+ // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<32xf32> -> !torch.vtensor<[32],f32>
+ // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>
+ %true = torch.constant.bool true
+ %int0 = torch.constant.int 0
+ %false = torch.constant.bool false
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+ %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+ %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,32,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[32,128,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>
+ return %result1, %result2 : !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],f32>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>) {
+func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,32,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>) {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32>
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32>
+ // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xf32> into tensor<2x4x4x33x33xf32>
+ // CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xf32> into tensor<2x4x32x64x64xf32>
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2]
+ // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+ // CHECK-NEXT: tensor.yield %[[CST]] : f32
+ // CHECK-NEXT: } : tensor<2x4x32x64x64xf32> to tensor<2x4x32x68x68xf32>
+ // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<16x32x2x2xf32>
+ // CHECK: %[[OUT0_EXP:.*]] = tensor.expand_shape %[[OUT0_EMPTY]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xf32> into tensor<4x4x32x2x2xf32>
+ // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EXP]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32>
+ // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xf32>, tensor<2x4x4x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+ // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+ // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
+ // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32>
+ // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CONV]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xf32> into tensor<16x32x2x2xf32>
+ // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xf32> -> !torch.vtensor<[16,32,2,2],f32>
+ // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+ // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+ // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+ // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+ // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+ // CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+ // CHECK-NEXT: } -> tensor<16xf32>
+ // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32>
+ // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>
+ %true = torch.constant.bool true
+ %int0 = torch.constant.int 0
+ %false = torch.constant.bool false
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %int4 = torch.constant.int 4
+ %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+ %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+ %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,32,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>
+ return %result1, %result2 : !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>
+}
+
+// -----
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 8fe502a7d686..6c1d051b570b 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -273,40 +273,6 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
 return %arg0 : !torch.int
 }
-// -----
-
-// CHECK-LABEL: func.func @convolution_backward_none_result(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,1,5,5],f32>,
-// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,1,3,3],f32>,
-// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) {
-func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,3,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) {
- // CHECK: %[[VAL_4:.*]] = torch.constant.int 3
- // CHECK: %[[VAL_5:.*]] = torch.constant.int 2
- // CHECK: %[[VAL_6:.*]] = torch.constant.none
- // CHECK: %[[VAL_7:.*]] = torch.constant.int 0
- // CHECK: %[[VAL_8:.*]] = torch.constant.bool false
- // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
- // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_9]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list<int>
- // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list<int>
- // CHECK: %[[VAL_12:.*]] = torch.aten.transpose.int %[[VAL_1]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
- // CHECK: %[[VAL_13:.*]] = torch.aten.transpose.int %[[VAL_0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
- // CHECK: %[[VAL_14:.*]] = torch.aten.convolution %[[VAL_12]], %[[VAL_13]], %[[VAL_6]], %[[VAL_10]], %[[VAL_11]], %[[VAL_10]], %[[VAL_8]], %[[VAL_11]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
- // CHECK: %[[VAL_15:.*]] = torch.aten.transpose.int %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
- // CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_5]], %[[VAL_4]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- // CHECK: %[[VAL_17:.*]] = torch.aten.sum.dim_IntList %[[VAL_0]], %[[VAL_16]], %[[VAL_8]], %[[VAL_6]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1],f32>
- // CHECK: return %[[VAL_15]], %[[VAL_17]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
- %true = torch.constant.bool true
- %int0 = torch.constant.int 0
- %false = torch.constant.bool false
- %int1 = torch.constant.int 1
- %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
- %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
- %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
- %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
- %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
- return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
-}
-
 // -----
 // CHECK-LABEL: func.func @emptyLikeNoneDtype(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {