diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 62fc212b23be..20f0d8b2a83f 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -25,6 +25,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include #include @@ -44,6 +45,158 @@ namespace mlir::torch { namespace { +static SmallVector +permuteShape(ArrayRef originalShape, + ArrayRef permutation) { + SmallVector result; + result.reserve(permutation.size()); + for (int32_t dim : permutation) + result.push_back(originalShape[dim]); + return result; +} + +struct ZeroInsertionResult { + Value value; + bool trimmedTail; +}; + +static FailureOr +insertZerosAlongAxis(Value input, int axis, int64_t stride, + ConversionPatternRewriter &rewriter, Location loc) { + if (stride == 1) + return ZeroInsertionResult{input, /*trimmedTail=*/true}; + + if (stride <= 0) + return failure(); + + auto inputType = dyn_cast(input.getType()); + if (!inputType || !inputType.hasStaticShape()) + return failure(); + + auto elementType = inputType.getElementType(); + SmallVector shape = llvm::to_vector(inputType.getShape()); + if (axis < 0 || axis >= static_cast(shape.size())) + return failure(); + + int64_t dim = shape[axis]; + if (ShapedType::isDynamic(dim)) + return failure(); + + SmallVector expandedShape; + expandedShape.reserve(shape.size() + 1); + for (int i = 0; i < static_cast(shape.size()); ++i) { + expandedShape.push_back(shape[i]); + if (i == axis) + expandedShape.push_back(1); + } + + auto expandedType = RankedTensorType::get( + makeShapeLLVMCompatible(expandedShape), elementType); + Value reshapeToExpanded = rewriter.create( + loc, expandedType, input, + tosa::getTosaConstShape(rewriter, loc, expandedShape)); + + SmallVector paddedShape = expandedShape; + paddedShape[axis + 1] = stride; + SmallVector pads(2 * expandedShape.size(), 0); + pads[2 * (axis + 1) + 1] = stride - 1; + + Value padsConst = tosa::getTosaConstShape(rewriter, loc, pads); + + Value padValue = + tosa::createZeroPointTensor(rewriter, loc, elementType, 0).value(); + + auto paddedType = RankedTensorType::get( + makeShapeLLVMCompatible(paddedShape), elementType); + Value padded = rewriter.create( + loc, paddedType, reshapeToExpanded, padsConst, padValue); + + SmallVector collapsedShape = shape; + collapsedShape[axis] = dim * stride; + auto collapsedType = RankedTensorType::get( + makeShapeLLVMCompatible(collapsedShape), elementType); + + Value result = rewriter.create( + loc, collapsedType, padded, + tosa::getTosaConstShape(rewriter, loc, collapsedShape)); + + bool trimmedTail = stride <= 1; + + if (dim != ShapedType::kDynamic && stride > 1) { + int64_t trimmedLength = (dim - 1) * stride + 1; + if (trimmedLength < collapsedShape[axis]) { + SmallVector startIndices(collapsedShape.size(), 0); + SmallVector sliceSizes = collapsedShape; + sliceSizes[axis] = trimmedLength; + SmallVector trimmedShape = + llvm::to_vector(collapsedType.getShape()); + trimmedShape[axis] = trimmedLength; + auto trimmedType = RankedTensorType::get( + makeShapeLLVMCompatible(trimmedShape), elementType); + result = rewriter.create( + loc, trimmedType, result, + tosa::getTosaConstShape(rewriter, loc, startIndices), + tosa::getTosaConstShape(rewriter, loc, sliceSizes)); + } + + trimmedTail = true; + } + + if (stride 
> 1 && dim == ShapedType::kDynamic) + trimmedTail = false; + + return ZeroInsertionResult{result, trimmedTail}; +} + +static LogicalResult +getTorchToTosaPermutations(Location loc, int64_t rank, + SmallVectorImpl &torchToTosa, + SmallVectorImpl &tosaToTorch) { + if (rank < 3) + return emitError(loc) << "expected convolution tensor rank >= 3, got " + << rank; + + torchToTosa.clear(); + tosaToTorch.clear(); + + torchToTosa.push_back(0); // batch dim stays first + for (int64_t dim = 2; dim < rank; ++dim) + torchToTosa.push_back(dim); // spatial dims in order + torchToTosa.push_back(1); // channel moves to last position + + tosaToTorch.resize(torchToTosa.size()); + for (auto pair : llvm::enumerate(torchToTosa)) + tosaToTorch[pair.value()] = pair.index(); + + return success(); +} + +static LogicalResult +getTorchConvWeightPermutation(Location loc, int64_t rank, bool isTransposed, + SmallVectorImpl &permutation) { + if (rank < 3) + return emitError(loc) << "expected convolution weight rank >= 3, got " + << rank; + + permutation.clear(); + + if (!isTransposed) { + // Torch weight layout: [O, I, spatial...]; TOSA expects [O, spatial..., I]. + permutation.push_back(0); + for (int64_t dim = 2; dim < rank; ++dim) + permutation.push_back(dim); + permutation.push_back(1); + } else { + // Transposed layout: [I, O, spatial...] -> [O, spatial..., I]. + permutation.push_back(1); + for (int64_t dim = 2; dim < rank; ++dim) + permutation.push_back(dim); + permutation.push_back(0); + } + + return success(); +} + // These legalizations are for unary ops with promoting input to floating-point // datatypes only. There is no supported quantized integer mode for these. template @@ -2377,9 +2530,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto weightShape = makeShapeTorchCompatible(weightTy.getShape()); auto outputElemTy = outputTy.getElementType(); - if (inputTy.getRank() != 4) + int64_t inputRank = inputTy.getRank(); + int64_t weightRank = weightTy.getRank(); + int64_t outputRank = outputTy.getRank(); + + if (inputRank != weightRank || outputRank != inputRank) + return rewriter.notifyMatchFailure( + op, "Input, weight and output ranks must match for convolution"); + + if (inputRank != 4 && inputRank != 5) return rewriter.notifyMatchFailure( - op, "Unimplemented: only 2D convolutions supported"); + op, "Unimplemented: only 2D or 3D convolutions supported"); + + bool is3D = inputRank == 5; + int64_t spatialRank = inputRank - 2; if (!weightTy.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -2390,12 +2554,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto bias = adaptor.getBias(); if (isa(bias.getType())) { - // ConvTranspose weights use IOHW; the helper expects OIHW, so swap - // dims 0/1 before we synthesize the bias. - SmallVector biasWeightShape = - transposed ? SmallVector{weightShape[1], weightShape[0], - weightShape[2], weightShape[3]} - : weightShape; + // For transposed conv (Torch weight = [IC, OC, KH, KW]) we want the bias + // length to match OC, so we pass a "conv-style" shape with OC in dim 0. 
+ SmallVector biasWeightShape; + if (transposed) { + biasWeightShape.push_back(weightShape[1]); + biasWeightShape.push_back(weightShape[0]); + biasWeightShape.append(weightShape.begin() + 2, weightShape.end()); + } else { + biasWeightShape = weightShape; + } auto biasResult = tosa::getConvBiasForNoneType( op, rewriter, inputElemTy, outputElemTy, biasWeightShape); @@ -2420,26 +2588,36 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "(depthwise convolution)"); } - SmallVector stride; + SmallVector stride; if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride))) return rewriter.notifyMatchFailure(op, "non-const stride list unsupported"); + if (static_cast(stride.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "stride rank mismatch"); - SmallVector padding_2d; + SmallVector paddingList; if (!matchPattern(adaptor.getPadding(), - m_TorchListOfConstantInts(padding_2d))) + m_TorchListOfConstantInts(paddingList))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); - // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D - // padding {height, width}. The PyTorch OFM computation uses 2*pad in each - // spatial direction, implying the same top=bottom=height and left=right=width - // values for TOSA. - SmallVector padding( - {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]}); - - SmallVector dilation; + if (static_cast(paddingList.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "padding rank mismatch"); + + // TOSA expects symmetric before/after padding per spatial dimension. + SmallVector padding; + if (is3D) { + padding = {paddingList[0], paddingList[0], paddingList[1], paddingList[1], + paddingList[2], paddingList[2]}; + } else { + padding = {paddingList[0], paddingList[0], paddingList[1], + paddingList[1]}; + } + + SmallVector dilation; if (!matchPattern(adaptor.getDilation(), m_TorchListOfConstantInts(dilation))) return rewriter.notifyMatchFailure(op, "non-const dilation list unsupported"); + if (static_cast(dilation.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "dilation rank mismatch"); TypeAttr accType; if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy, @@ -2447,54 +2625,208 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get accumulator type for convolution ops"); - // Weight layout reference: - // Conv : PyTorch OIHW -> TOSA OHWI - // Depthwise : PyTorch OIHW* -> TOSA HWIM - // (PyTorch depthwise uses out_ch=in_ch*depth_multiplier) - // Grouped : PyTorch O(I/G)HW -> N/A - // Transposed : PyTorch IOHW -> TOSA OHWI - // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. - // Perform the necessary transformations. - SmallVector nchwToNhwcDims({0, 2, 3, 1}); - SmallVector nhwcToNchwDims({0, 3, 1, 2}); - SmallVector transposedInputShape; - for (int32_t dim : nchwToNhwcDims) - transposedInputShape.push_back(inputShape[dim]); + + // TOSA works in NHWC (2D) / NDHWC (3D) and takes OHWI / ODHWI weights for + // convolution. Perform the necessary transformations. 
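The permutations built just below by getTorchToTosaPermutations reduce to simple index arithmetic: the batch dim stays first, spatial dims keep their order, and the channel dim moves to the end (and back again for the inverse). A standalone sketch of that arithmetic, assuming a rank-5 NCDHW shape; this is plain C++ for illustration, not the rewriter code:

// Standalone sketch of the NCHW/NCDHW <-> NHWC/NDHWC permutation logic.
// It only mirrors the index arithmetic used by the lowering.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int32_t> torchToTosaPerm(int64_t rank) {
  // Batch stays first, spatial dims keep their order, channel moves last.
  std::vector<int32_t> perm{0};
  for (int64_t d = 2; d < rank; ++d)
    perm.push_back(static_cast<int32_t>(d));
  perm.push_back(1);
  return perm;
}

static std::vector<int32_t> invertPerm(const std::vector<int32_t> &perm) {
  std::vector<int32_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int32_t>(i);
  return inv;
}

static std::vector<int64_t> applyPerm(const std::vector<int64_t> &shape,
                                      const std::vector<int32_t> &perm) {
  std::vector<int64_t> out;
  for (int32_t d : perm)
    out.push_back(shape[d]);
  return out;
}

int main() {
  // Rank-5 (3D convolution) example: NCDHW -> NDHWC and back.
  std::vector<int64_t> ncdhw{2, 3, 5, 6, 7};
  auto toTosa = torchToTosaPerm(5);      // {0, 2, 3, 4, 1}
  auto toTorch = invertPerm(toTosa);     // {0, 4, 1, 2, 3}
  auto ndhwc = applyPerm(ncdhw, toTosa); // {2, 5, 6, 7, 3}
  assert(applyPerm(ndhwc, toTorch) == ncdhw);
  return 0;
}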
+ SmallVector torchToTosaDims; + SmallVector tosaToTorchDims; + if (failed(getTorchToTosaPermutations(op->getLoc(), inputRank, + torchToTosaDims, tosaToTorchDims))) + return rewriter.notifyMatchFailure(op, + "unsupported convolution input rank"); + + SmallVector transposedInputShape = + permuteShape(inputShape, torchToTosaDims); auto transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); - auto createTransposedInput = [&]() { - return tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(transposedInputType), input, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) - .getResult(); - }; + auto transposedInput = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transposedInputType), input, + rewriter.getDenseI32ArrayAttr(torchToTosaDims)) + .getResult(); if (transposed) { - if (groups != 1) - return rewriter.notifyMatchFailure( - op, "Unimplemented: grouped transposed convolution not supported by " - "TOSA"); - if (dilation[0] != 1 || dilation[1] != 1) + SmallVector outPaddingList; + if (!matchPattern(adaptor.getOutputPadding(), + m_TorchListOfConstantInts(outPaddingList))) return rewriter.notifyMatchFailure( - op, "Unimplemented: dilated transposed convolution not supported by " - "TOSA"); + op, "non-const output_padding list unsupported for transposed conv"); + if (static_cast(outPaddingList.size()) != spatialRank) + return rewriter.notifyMatchFailure(op, "output_padding rank mismatch"); - SmallVector iohwToOhwi({1, 2, 3, 0}); + if (is3D) { + if (groups != 1) + return rewriter.notifyMatchFailure( + op, "Unimplemented: grouped transposed 3D convolution not " + "supported by TOSA"); - // TOSA 'out_pad' is a 4D array {top,bottom,left,right}. - // Map from PyTorch's (padding, output_padding): - // out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W) - // Negative values are allowed and will be handled by the TOSA - // decomposition. - SmallVector outPadding2D; - if (!matchPattern(adaptor.getOutputPadding(), - m_TorchListOfConstantInts(outPadding2D))) + SmallVector weightPermutation; + if (failed(getTorchConvWeightPermutation(op->getLoc(), weightRank, + /*isTransposed=*/true, + weightPermutation))) + return rewriter.notifyMatchFailure( + op, "unsupported convolution weight rank for transpose conv"); + + SmallVector transformedWeightShape = + permuteShape(weightShape, weightPermutation); + auto transformedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); + Value transformedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), + weight, rewriter.getDenseI32ArrayAttr(weightPermutation)) + .getResult(); + + // Reverse spatial dims of the kernel. 
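The kernel flip, together with the zero-insertion upsampling performed below, is the usual decomposition of a stride-s transposed convolution into a stride-1 convolution. A minimal 1-D size check of that decomposition, assuming the tail zeros are trimmed as insertZerosAlongAxis does for static shapes (standalone sketch, not the lowering itself):

// 1-D size bookkeeping for the "upsample with zeros + flipped-kernel conv"
// decomposition of a transposed convolution (illustrative sketch only).
#include <cassert>
#include <cstdint>

int main() {
  const int64_t in = 2, k = 3, stride = 2, pad = 1, dil = 1, outPad = 1;

  // PyTorch ConvTranspose output length.
  int64_t expected = (in - 1) * stride - 2 * pad + dil * (k - 1) + outPad + 1;

  // Zero-insertion route: (in - 1) * stride + 1 samples after upsampling
  // (tail zeros trimmed), then pad by dil*(k-1) - pad before and that plus
  // outPad after, then a stride-1 convolution with the flipped kernel.
  int64_t upsampled = (in - 1) * stride + 1;
  int64_t before = dil * (k - 1) - pad;
  int64_t after = before + outPad;
  int64_t convOut = upsampled + before + after - dil * (k - 1);

  assert(convOut == expected); // both give 4 for these parameters
  return 0;
}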
+ Value flippedWeight = transformedWeight; + for (int reverseAxis = 1; reverseAxis <= 3; ++reverseAxis) { + flippedWeight = rewriter.create( + op->getLoc(), flippedWeight.getType(), flippedWeight, + rewriter.getI32IntegerAttr(reverseAxis)); + } + + Value upsampledInput = transposedInput; + SmallVector tailTrimmed(/*N=*/3, /*Value=*/true); + Location loc = op->getLoc(); + for (int axis = 0; axis < 3; ++axis) { + auto insertedResult = + insertZerosAlongAxis(upsampledInput, axis + 1, stride[axis], + rewriter, loc); + if (failed(insertedResult)) + return rewriter.notifyMatchFailure( + op, "Unsupported parameters for transposed 3D convolution"); + upsampledInput = insertedResult->value; + tailTrimmed[axis] = insertedResult->trimmedTail; + } + + auto upsampledType = cast(upsampledInput.getType()); + SmallVector upsampledShape = + llvm::to_vector(upsampledType.getShape()); + + SmallVector padVec(2 * upsampledShape.size(), 0); + SmallVector paddedShape = upsampledShape; + for (int axis = 0; axis < 3; ++axis) { + int spatialIndex = axis + 1; // NDHWC ordering + int64_t kernel = weightShape[2 + axis]; + int64_t dil = dilation[axis]; + int64_t pad = paddingList[axis]; + int64_t outPad = outPaddingList[axis]; + int64_t before = dil * (kernel - 1) - pad; + int64_t after = before + outPad; + if (!tailTrimmed[axis]) { + after -= (stride[axis] - 1); + } + if (before < 0 || after < 0) + return rewriter.notifyMatchFailure( + op, "Unsupported padding combination for transposed 3D " + "convolution"); + padVec[2 * spatialIndex] = before; + padVec[2 * spatialIndex + 1] = after; + if (!ShapedType::isDynamic(paddedShape[spatialIndex])) + paddedShape[spatialIndex] += before + after; + else + paddedShape[spatialIndex] = ShapedType::kDynamic; + } + + Value paddedInput = upsampledInput; + if (llvm::any_of(padVec, [](int64_t v) { return v != 0; })) { + Value padsConst = tosa::getTosaConstShape(rewriter, loc, padVec); + Value padTensor = tosa::createZeroPointTensor(rewriter, loc, + inputElemTy, 0) + .value(); + auto paddedType = RankedTensorType::get( + makeShapeLLVMCompatible(paddedShape), inputElemTy); + paddedInput = rewriter.create(loc, paddedType, + upsampledInput, padsConst, + padTensor); + } + + auto outTorchShape = makeShapeTorchCompatible(outputTy.getShape()); + SmallVector outTosaShape = + permuteShape(outTorchShape, torchToTosaDims); + auto convOpTy = RankedTensorType::get( + makeShapeLLVMCompatible(outTosaShape), biasElemTy); + + auto zps = tosa::createZPsAsConst(rewriter, input, weight); + Value inputZp = zps.first + ? zps.first + : tosa::createZeroPointTensor(rewriter, loc, + inputElemTy, 0) + .value(); + Value weightZp = zps.second + ? 
zps.second + : tosa::createZeroPointTensor(rewriter, loc, + weightElemTy, 0) + .value(); + + auto convResult = + rewriter + .create( + loc, getTypeConverter()->convertType(convOpTy), + paddedInput, flippedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr({0, 0, 0, 0, 0, 0}), + rewriter.getDenseI64ArrayAttr({1, 1, 1}), + rewriter.getDenseI64ArrayAttr(dilation), accType) + .getResult(); + + SmallVector transposedOutputShape = + permuteShape(outTosaShape, tosaToTorchDims); + auto transposedOutputType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); + Value transposedOutput = + rewriter + .create( + loc, getTypeConverter()->convertType(transposedOutputType), + convResult, rewriter.getDenseI32ArrayAttr(tosaToTorchDims)) + .getResult(); + + Value rescaledResult = transposedOutput; + if (isa(inputElemTy)) { + rescaledResult = tosa::buildRescaleOpConvOutput( + rewriter, op, transposedOutput, inputTy, weightTy, outputTy); + } + + rewriter.replaceOp( + op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy) + .value()}); + return success(); + } + + if (groups != 1) + return rewriter.notifyMatchFailure(op, + "Unimplemented: grouped transposed " + "convolution not supported by TOSA"); + if (dilation[0] != 1 || dilation[1] != 1) + return rewriter.notifyMatchFailure(op, + "Unimplemented: dilated transposed " + "convolution not supported by TOSA"); + + // Torch (ConvTranspose2d) weight is [IC, OC, KH, KW]. + // TOSA transpose_conv2d requires [OC, KH, KW, IC]. + SmallVector weightPermutation; + if (failed(getTorchConvWeightPermutation(op->getLoc(), weightRank, + /*isTransposed=*/true, + weightPermutation))) return rewriter.notifyMatchFailure( - op, "non-const output_padding list unsupported for transposed conv"); + op, "unsupported convolution weight rank for transpose conv"); + + SmallVector ohwiWeightShape = + permuteShape(weightShape, weightPermutation); + auto ohwiWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy); + Value transformedWeight = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(ohwiWeightType), + weight, rewriter.getDenseI32ArrayAttr(weightPermutation)) + .getResult(); - int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0]; - int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1]; + int64_t outPadH = outPaddingList[0] - 2 * paddingList[0]; + int64_t outPadW = outPaddingList[1] - 2 * paddingList[1]; int64_t outPadTop = outPadH / 2; int64_t outPadBottom = outPadH - outPadTop; int64_t outPadLeft = outPadW / 2; @@ -2502,28 +2834,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector outPad( {outPadTop, outPadBottom, outPadLeft, outPadRight}); - Value nhwcInput = createTransposedInput(); - SmallVector ohwiWeightShape; - for (int32_t dim : iohwToOhwi) - ohwiWeightShape.push_back(weightShape[dim]); - auto ohwiWeightType = RankedTensorType::get( - makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy); - Value transformedWeight = - tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(ohwiWeightType), weight, - rewriter.getDenseI32ArrayAttr(iohwToOhwi)) - .getResult(); - // Result type is NHWC (we'll transpose back). 
- auto outNCHW = makeShapeTorchCompatible(outputTy.getShape()); - SmallVector outNHWC; - for (int32_t dim : nchwToNhwcDims) - outNHWC.push_back(outNCHW[dim]); + auto outTorchShape = makeShapeTorchCompatible(outputTy.getShape()); + SmallVector outTosaShape = + permuteShape(outTorchShape, torchToTosaDims); auto transConvOpTy = - RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy); + RankedTensorType::get(makeShapeLLVMCompatible(outTosaShape), + biasElemTy); - // Zero-points. + // Zero-points (same helpers as conv path). auto zps = tosa::createZPsAsConst(rewriter, input, weight); Value inputZp = zps.first ? zps.first : tosa::createZeroPointTensor( @@ -2534,27 +2853,35 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter, op->getLoc(), weightElemTy, 0) .value(); - Value convTOut = tosa::TransposeConv2DOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(transConvOpTy), - nhwcInput, transformedWeight, bias, inputZp, weightZp, - rewriter.getDenseI64ArrayAttr(outPad), - rewriter.getDenseI64ArrayAttr(stride), accType) - .getResult(); + // Build tosa.transpose_conv2d. + Value convTOut = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(transConvOpTy), + /*input*/ transposedInput, + /*weight*/ transformedWeight, + /*bias*/ bias, + /*input_zp*/ inputZp, + /*weight_zp*/ weightZp, + /*out_pad*/ rewriter.getDenseI64ArrayAttr(outPad), + /*stride*/ rewriter.getDenseI64ArrayAttr(stride), + /*acc_type*/ accType) + .getResult(); - SmallVector transposedOutputShape; - for (int32_t dim : nhwcToNchwDims) - transposedOutputShape.push_back(outNHWC[dim]); + // NHWC -> NCHW + SmallVector transposedOutputShape = + permuteShape(outTosaShape, tosaToTorchDims); auto transposedOutputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); Value transposedOutput = - tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(transposedOutputType), convTOut, - rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transposedOutputType), convTOut, + rewriter.getDenseI32ArrayAttr(tosaToTorchDims)) .getResult(); - // Quantized rescale. + // Quantized rescale (reuse existing helper). 
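The out_pad values computed above follow the mapping out_pad_total = output_padding - 2 * padding per spatial dimension. A small standalone check that this reproduces the PyTorch ConvTranspose2d output size under the TOSA transpose_conv2d size rule (illustrative values, dilation assumed to be 1):

// Sanity check of the out_pad mapping used for tosa.transpose_conv2d:
// out_pad_total = output_padding - 2 * padding per spatial dim.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t inH = 10, kH = 3, stride = 2, pad = 1, outputPadding = 1;

  // PyTorch ConvTranspose2d output height (dilation = 1).
  int64_t torchOut = (inH - 1) * stride - 2 * pad + (kH - 1) + outputPadding + 1;

  // TOSA transpose_conv2d: OH = (IH - 1) * stride + out_pad_top + out_pad_bottom + KH.
  int64_t outPadTotal = outputPadding - 2 * pad;
  int64_t outPadTop = outPadTotal / 2;            // may be negative; the TOSA
  int64_t outPadBottom = outPadTotal - outPadTop; // decomposition handles that
  int64_t tosaOut = (inH - 1) * stride + outPadTop + outPadBottom + kH;

  assert(torchOut == tosaOut); // both are 20 for these parameters
  return 0;
}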
Value rescaledResult = transposedOutput; if (isa(inputElemTy)) { rescaledResult = tosa::buildRescaleOpConvOutput( @@ -2567,25 +2894,31 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value()}); return success(); } - SmallVector transformedWeightShape; RankedTensorType transformedWeightType; Value transformedWeight; int64_t outputCDim; + SmallVector weightPermutation; + if (failed(getTorchConvWeightPermutation(op->getLoc(), weightRank, + /*isTransposed=*/false, + weightPermutation))) + return rewriter.notifyMatchFailure( + op, "unsupported convolution weight rank"); + if (groups == 1) { - // full convolution: O(I/G)HW-> OHWI - transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3], - weightShape[1]}; + // full convolution: O(I/G)spatial -> O spatial I + transformedWeightShape = permuteShape(weightShape, weightPermutation); transformedWeightType = RankedTensorType::get( makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); transformedWeight = - tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(transformedWeightType), weight, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), weight, + rewriter.getDenseI32ArrayAttr(weightPermutation)) .getResult(); outputCDim = transformedWeightShape[0]; - } else if (weightShape[1] == 1) { + } else if (!is3D && weightShape[1] == 1) { // depthwise convolution: O(I/G)HW-> HWIM) // transpose: O(I/G)HW -> HWO(I/G) SmallVector transposedDims({2, 3, 0, 1}); @@ -2626,84 +2959,163 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transformedWeightShape)) .getResult(); } else { - llvm_unreachable("Unhandled convolution type"); - } - - Value transposedInput = createTransposedInput(); - - int64_t outputHDim, outputWDim; - int64_t inputHDim = inputShape[2]; - int64_t inputWDim = inputShape[3]; - - bool isStaticSpatialDims = - !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim); - if (isStaticSpatialDims) { - - int64_t weightHDim = weightShape[2]; - int64_t weightWDim = weightShape[3]; - - // fullDim = - // inputDim + padBefore + padAfter - dilation * (weightDim - 1) - 1 - // According to TOSA spec: - // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d, fullDim values - // must be divisible by stride values. - int64_t fullHDim = inputHDim + padding[0] + padding[1] - - dilation[0] * (weightHDim - 1) - 1; - int64_t remainderHDim = fullHDim % stride[0]; - if (remainderHDim != 0) { - if (remainderHDim > padding[1]) { - SmallVector startHSlice(inputTy.getRank(), 0); - SmallVector sizeHSlice(transposedInputShape); - // TOSA uses NHWC, so we will slice dim 1 for Height value - sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); - transposedInput = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), - transposedInput, - tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice), - tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice)); - fullHDim = fullHDim - padding[1]; - padding[1] = 0; - } else { - fullHDim = fullHDim - padding[1]; - padding[1] = padding[1] - remainderHDim; - fullHDim = fullHDim + padding[1]; + return rewriter.notifyMatchFailure( + op, is3D ? 
"Unimplemented: grouped or depthwise 3D convolution " + "not supported by TOSA" + : "Unhandled convolution type"); + } + + + SmallVector outputShape; + if (!is3D) { + int64_t outputHDim, outputWDim; + int64_t inputHDim = inputShape[2]; + int64_t inputWDim = inputShape[3]; + + bool isStaticSpatialDims = + !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim); + if (isStaticSpatialDims) { + int64_t weightHDim = weightShape[2]; + int64_t weightWDim = weightShape[3]; + + + // fullDim = inputDim + padBefore + padAfter - dilation * (weightDim - 1) + // - 1. According to TOSA spec fullDim values must be divisible by the + // stride values. + int64_t fullHDim = inputHDim + padding[0] + padding[1] - + dilation[0] * (weightHDim - 1) - 1; + int64_t remainderHDim = fullHDim % stride[0]; + if (remainderHDim != 0) { + if (remainderHDim > padding[1]) { + SmallVector startHSlice(inputTy.getRank(), 0); + SmallVector sizeHSlice(transposedInputShape); + // TOSA uses NHWC, so slice dim 1 for Height. + sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice)); + fullHDim = fullHDim - padding[1]; + padding[1] = 0; + } else { + fullHDim = fullHDim - padding[1]; + padding[1] = padding[1] - remainderHDim; + fullHDim = fullHDim + padding[1]; + } } - } - outputHDim = fullHDim / stride[0] + 1; - - int64_t fullWDim = inputWDim + padding[2] + padding[3] - - dilation[1] * (weightWDim - 1) - 1; - int64_t remainderWDim = fullWDim % stride[1]; - if (remainderWDim != 0) { - if (remainderWDim > padding[3]) { - SmallVector startWSlice(inputTy.getRank(), 0); - SmallVector sizeWSlice( - dyn_cast(transposedInput.getType()).getShape()); - // TOSA uses NHWC, so we will slice dim 2 for Width value - sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]); - transposedInput = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), - transposedInput, - tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice), - tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice)); - fullHDim = fullHDim - padding[3]; - padding[3] = 0; - } else { - fullWDim = fullWDim - padding[3]; - padding[3] = padding[3] - remainderWDim; - fullWDim = fullWDim + padding[3]; + outputHDim = fullHDim / stride[0] + 1; + + int64_t fullWDim = inputWDim + padding[2] + padding[3] - + dilation[1] * (weightWDim - 1) - 1; + int64_t remainderWDim = fullWDim % stride[1]; + if (remainderWDim != 0) { + if (remainderWDim > padding[3]) { + SmallVector startWSlice(inputTy.getRank(), 0); + SmallVector sizeWSlice( + dyn_cast(transposedInput.getType()).getShape()); + // TOSA uses NHWC, so slice dim 2 for Width. 
+ sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]); + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice)); + fullWDim = fullWDim - padding[3]; + padding[3] = 0; + } else { + fullWDim = fullWDim - padding[3]; + padding[3] = padding[3] - remainderWDim; + fullWDim = fullWDim + padding[3]; + } } + outputWDim = fullWDim / stride[1] + 1; + } else { + outputHDim = kUnknownSize; + outputWDim = kUnknownSize; } - outputWDim = fullWDim / stride[1] + 1; + + outputShape = {transposedInputShape[0], outputHDim, outputWDim, outputCDim}; } else { - outputHDim = kUnknownSize; - outputWDim = kUnknownSize; - } + int64_t outputDDim, outputHDim, outputWDim; + int64_t inputDDim = inputShape[2]; + int64_t inputHDim = inputShape[3]; + int64_t inputWDim = inputShape[4]; + + bool isStaticSpatialDims = !ShapedType::isDynamic(inputDDim) && + !ShapedType::isDynamic(inputHDim) && + !ShapedType::isDynamic(inputWDim); + SmallVector currentShape = transposedInputShape; + + auto adjustSpatialDim = [&](int axis, int padBeforeIdx, int padAfterIdx, + int64_t weightDim, int64_t strideVal, + int64_t dilationVal, int64_t &outputDim) + -> LogicalResult { + int nhwcAxis = axis + 1; + int64_t inputDim = currentShape[nhwcAxis]; + int64_t fullDim = inputDim + padding[padBeforeIdx] + padding[padAfterIdx] - + dilationVal * (weightDim - 1) - 1; + int64_t remainder = fullDim % strideVal; + if (remainder != 0) { + if (remainder > padding[padAfterIdx]) { + SmallVector startSlice(currentShape.size(), 0); + SmallVector sizeSlice = currentShape; + sizeSlice[nhwcAxis] = + inputDim - (remainder - padding[padAfterIdx]); + transposedInput = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); + fullDim = fullDim - padding[padAfterIdx]; + padding[padAfterIdx] = 0; + } else { + fullDim = fullDim - padding[padAfterIdx]; + padding[padAfterIdx] = padding[padAfterIdx] - remainder; + fullDim = fullDim + padding[padAfterIdx]; + } + } + outputDim = fullDim / strideVal + 1; + return success(); + }; + + if (isStaticSpatialDims) { + int64_t weightDDim = weightShape[2]; + int64_t weightHDim = weightShape[3]; + int64_t weightWDim = weightShape[4]; + + if (failed(adjustSpatialDim(/*axis=*/0, /*padBeforeIdx=*/0, + /*padAfterIdx=*/1, weightDDim, stride[0], + dilation[0], outputDDim))) + return failure(); + if (auto updatedType = + dyn_cast(transposedInput.getType())) + currentShape = llvm::to_vector(updatedType.getShape()); + + if (failed(adjustSpatialDim(/*axis=*/1, /*padBeforeIdx=*/2, + /*padAfterIdx=*/3, weightHDim, stride[1], + dilation[1], outputHDim))) + return failure(); + if (auto updatedType = + dyn_cast(transposedInput.getType())) + currentShape = llvm::to_vector(updatedType.getShape()); + + if (failed(adjustSpatialDim(/*axis=*/2, /*padBeforeIdx=*/4, + /*padAfterIdx=*/5, weightWDim, stride[2], + dilation[2], outputWDim))) + return failure(); + if (auto updatedType = + dyn_cast(transposedInput.getType())) + currentShape = llvm::to_vector(updatedType.getShape()); + } else { + outputDDim = kUnknownSize; + outputHDim = kUnknownSize; + outputWDim = kUnknownSize; + } - // Output shape is NHWC, to be transposed back to NCHW. 
Output elemTy for - // quantized input is i32, which gets rescaled down to quantized output range. - SmallVector outputShape = {transposedInputShape[0], outputHDim, - outputWDim, outputCDim}; + outputShape = {currentShape[0], outputDDim, outputHDim, outputWDim, + outputCDim}; + } auto convOpTy = RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); @@ -2725,15 +3137,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value convOpResult; if (groups == 1) { // full convolution - convOpResult = - tosa::Conv2DOp::create( - rewriter, op->getLoc(), getTypeConverter()->convertType(convOpTy), - transposedInput, transformedWeight, bias, inputZp, weightZp, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation), accType) - .getResult(); - } else if (weightShape[1] == 1) { + if (is3D) { + convOpResult = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) + .getResult(); + } else { + convOpResult = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) + .getResult(); + } + } else if (!is3D && weightShape[1] == 1) { // depthwise convolution convOpResult = tosa::DepthwiseConv2DOp::create( @@ -2744,18 +3169,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else { - llvm_unreachable("Unhandled convolution type"); + return rewriter.notifyMatchFailure( + op, is3D ? 
"Unimplemented: grouped or depthwise 3D convolution " + "not supported by TOSA" + : "Unhandled convolution type"); } - SmallVector transposedOutputShape( - {outputShape[0], outputShape[3], outputShape[1], outputShape[2]}); + + SmallVector transposedOutputShape = + permuteShape(outputShape, tosaToTorchDims); auto transposedOutputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); auto transposedOutput = - tosa::TransposeOp::create( - rewriter, op->getLoc(), - getTypeConverter()->convertType(transposedOutputType), convOpResult, - rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transposedOutputType), + convOpResult, rewriter.getDenseI32ArrayAttr(tosaToTorchDims)) .getResult(); Value rescaledResult = transposedOutput; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4258547b94ee..6a6da250ae3c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -70,9 +70,6 @@ "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", "Conv_Transpose2dModule_basic", - "Conv_Transpose2dStaticModule_basic", - "Conv_Transpose3dModule_basic", - "Conv_Transpose3dStaticModule_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "ConvolutionModule2DTransposeStrided_basic", "GridSamplerBasic1_basic", @@ -1138,8 +1135,6 @@ "ConvolutionBackwardModule2DStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose2dStaticModule_basic", - "Conv_Transpose3dStaticModule_basic", "ConstantPad2dStaticModule_basic", "ConstantPadNdModule_basic", "ConstantPadNdPartialStaticModule_basic", @@ -2170,6 +2165,9 @@ "Conv2dWithValidPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Convolution2DStaticModule_basic", + "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", "CosineSimilarityStaticModule_basic", "DetachModule_basic", "DropoutEvalFloatModule_basic", @@ -2906,12 +2904,6 @@ "Conv2dWithPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Conv2dWithValidPaddingModule_basic", - "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", - "ConvolutionModule3DGroups_basic", - "ConvolutionModule3DGroupsStrided_basic", - "ConvolutionModule3DGroupsDilated_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -3370,7 +3362,6 @@ # Failure - unknown "BernoulliModule_basic", "Conv_Transpose1dModule_basic", - "Conv_Transpose3dModule_basic", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", @@ -3581,8 +3572,6 @@ "AvgPool3dCountIncludePadFalseWithoutPadding_basic", "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose3dModule_basic", - "Conv_Transpose3dStaticModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3688,9 +3677,6 @@ "Conv2dQInt8PerChannelModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", @@ -4111,9 +4097,6 @@ "AvgPool3dStaticModule_basic", 
"Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose2dStaticModule_basic", - "Conv_Transpose3dModule_basic", - "Conv_Transpose3dStaticModule_basic", "ElementwiseFmaxModule_basic", "ElementwiseFminModule_basic", "ElementwiseGeluApproximateTanhModule_basic", @@ -4327,9 +4310,6 @@ "Conv2dWithPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Conv2dWithValidPaddingModule_basic", - "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index bb04a9772f99..74bb15243b49 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -3627,11 +3627,11 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor<5x14x24x10xf32>) -> tensor<5x10x14x24xf32> // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> // CHECK: return %[[VAL_18]] : !torch.vtensor<[5,10,14,24],f32> @@ -3666,13 +3666,13 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : 
(tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32> // CHECK: %[[VAL_17:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_18:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.depthwise_conv2d %[[VAL_16]], %[[VAL_15]], %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32> +// CHECK: %[[VAL_19:.*]] = tosa.depthwise_conv2d %[[VAL_13]], %[[VAL_16]], %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32> // CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_19]] {perms = array} : (tensor<5x5x10x4xf32>) -> tensor<5x4x5x10xf32> // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[5,4,5,10],f32> @@ -3694,6 +3694,70 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3 // ----- +// CHECK-LABEL: func.func @torch.aten.convolution$3d_basic( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[2,3,5,6,7],f32>) -> !torch.vtensor<[2,4,5,6,7],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,6,7],f32> -> tensor<2x3x5x6x7xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource : tensor<4x3x3x3x3xf32>}> : () -> tensor<4x3x3x3x3xf32> +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<2x3x5x6x7xf32>) -> tensor<2x5x6x7x3xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<4x3x3x3x3xf32>) -> tensor<4x3x3x3x3xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () 
-> tensor<1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.conv3d %[[VAL_11]], %[[VAL_12]], %[[VAL_10]], %[[VAL_13]], %[[VAL_14]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x5x6x7x3xf32>, tensor<4x3x3x3x3xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x6x7x4xf32> +// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_15]] {perms = array} : (tensor<2x5x6x7x4xf32>) -> tensor<2x4x5x6x7xf32> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<2x4x5x6x7xf32> -> !torch.vtensor<[2,4,5,6,7],f32> +// CHECK: return %[[VAL_17]] : !torch.vtensor<[2,4,5,6,7],f32> +func.func @torch.aten.convolution$3d_basic(%arg0: !torch.vtensor<[2,3,5,6,7],f32>) -> !torch.vtensor<[2,4,5,6,7],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense_resource : tensor<4x3x3x3x3xf32>) : !torch.vtensor<[4,3,3,3,3],f32> + %none = torch.constant.none + %1 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[2,3,5,6,7],f32>, !torch.vtensor<[4,3,3,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[2,4,5,6,7],f32> + return %5 : !torch.vtensor<[2,4,5,6,7],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$3d_transpose( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1,2,2,2],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,2,2,2],f32> -> tensor<1x1x2x2x2xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x3x3x3xf32>}> : () -> tensor<1x1x3x3x3xf32> +// CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x1x2x2x2xf32>) -> tensor<1x2x2x2x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_2]] {perms = array} : (tensor<1x1x3x3x3xf32>) -> tensor<1x3x3x3x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 1 : i32} +// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} +// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 3 : i32} +// CHECK: %[[VAL_8:.*]] = tosa.conv3d %{{.*}}, %[[VAL_7]], %{{.*}}, %{{.*}}, %{{.*}} +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_8]] {perms = array} : (tensor<1x4x4x4x1xf32>) -> tensor<1x1x4x4x4xf32> +// CHECK: return %{{.*}} : !torch.vtensor<[1,1,4,4,4],f32> +func.func @torch.aten.convolution$3d_transpose(%arg0: !torch.vtensor<[1,1,2,2,2],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> { + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %weight = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x1x3x3x3xf32>) : !torch.vtensor<[1,1,3,3,3],f32> + %bias = torch.vtensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32> + %stride = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) 
-> !torch.list + %out_padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %groups = torch.constant.int 1 + %result = torch.aten.convolution %arg0, %weight, %bias, %stride, %padding, %dilation, %true, %out_padding, %groups : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,3,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32> + return %result : !torch.vtensor<[1,1,4,4,4],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.convolution$zero_pad_with_sliced_input( // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,128,28,28],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,56,56],f32> -> tensor<1x64x56x56xf32> @@ -3708,17 +3772,17 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32> -// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32> // CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {values = dense<[1, 55, 56, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_17:.*]] = tosa.slice %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32> +// CHECK: %[[VAL_17:.*]] = tosa.slice %[[VAL_13]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32> // CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 55, 55, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : (tensor<1x55x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x55x64xf32> // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_23:.*]] = tosa.conv2d %[[VAL_20]], %[[VAL_13]], %[[VAL_12]], %[[VAL_21]], %[[VAL_22]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32> +// CHECK: %[[VAL_23:.*]] = tosa.conv2d %[[VAL_20]], %[[VAL_14]], %[[VAL_12]], %[[VAL_21]], %[[VAL_22]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> 
tensor<1x28x28x128xf32> // CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_23]] {perms = array} : (tensor<1x28x28x128xf32>) -> tensor<1x128x28x28xf32> // CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<1x128x28x28xf32> -> !torch.vtensor<[1,128,28,28],f32> // CHECK: return %[[VAL_25]] : !torch.vtensor<[1,128,28,28],f32> @@ -3753,11 +3817,11 @@ func.func @torch.aten.convolution$zero_pad_with_sliced_input(%arg0: !torch.vtens // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32> +// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32> // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor<1x112x112x32xf32>) -> tensor<1x32x112x112xf32> // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x32x112x112xf32> -> !torch.vtensor<[1,32,112,112],f32> // CHECK: return %[[VAL_18]] : !torch.vtensor<[1,32,112,112],f32> @@ -3791,17 +3855,17 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> // CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_16:.*]] = tosa.slice 
%[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32> +// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32> // CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x224x3xf32> // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_12]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32> +// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32> // CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array} : (tensor<1x75x75x32xf32>) -> tensor<1x32x75x75xf32> // CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<1x32x75x75xf32> -> !torch.vtensor<[1,32,75,75],f32> // CHECK: return %[[VAL_24]] : !torch.vtensor<[1,32,75,75],f32> @@ -3836,11 +3900,11 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor) -> tensor // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor -> 
!torch.vtensor<[?,32,112,112],f32> // CHECK: return %[[VAL_18]] @@ -3875,17 +3939,17 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_IN:.*]] {perms = array} : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32> // CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor // CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_12]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor +// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor // CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array} : (tensor) -> tensor // CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor -> !torch.vtensor<[?,32,75,75],f32> // CHECK: return %[[VAL_24]] diff --git a/test/Conversion/TorchToTosa/conv3d_transpose.mlir b/test/Conversion/TorchToTosa/conv3d_transpose.mlir new file mode 100644 index 000000000000..ddca44e73d7c --- /dev/null +++ b/test/Conversion/TorchToTosa/conv3d_transpose.mlir @@ -0,0 +1,31 @@ +// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @convtranspose3d( +// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,1,2,2,2],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> { +// CHECK: %[[BT:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,1,2,2,2],f32> -> tensor<1x1x2x2x2xf32> +// CHECK: %[[WCONST:.*]] = "tosa.const" +// CHECK: %[[PERM_INPUT:.*]] = tosa.transpose %[[BT]] {perms = array} +// CHECK: %[[PERM_WEIGHT:.*]] = tosa.transpose %[[WCONST]] {perms = array} +// CHECK: %[[REV1:.*]] = 
tosa.reverse %[[PERM_WEIGHT]] {axis = 1 : i32}
+// CHECK: %[[REV2:.*]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
+// CHECK: %[[REV3:.*]] = tosa.reverse %[[REV2]] {axis = 3 : i32}
+// CHECK: %[[CONV:.*]] = tosa.conv3d %{{.*}} %[[REV3]]
+// CHECK: %[[FINAL:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 4, 1, 2, 3>}
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[FINAL]]
+// CHECK: return %[[RESULT]]
+// CHECK: }
+func.func @convtranspose3d(%input: !torch.vtensor<[1,1,2,2,2],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> {
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %weight = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x1x3x3x3xf32>) : !torch.vtensor<[1,1,3,3,3],f32>
+  %bias = torch.vtensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+  %stride = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %out_padding = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %groups = torch.constant.int 1
+  %result = torch.aten.convolution %input, %weight, %bias, %stride, %padding, %dilation, %true, %out_padding, %groups : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,3,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32>
+  return %result : !torch.vtensor<[1,1,4,4,4],f32>
+}
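For reference, the expected !torch.vtensor<[1,1,4,4,4],f32> result type in this test follows from the per-axis size bookkeeping of the decomposition exercised above; a standalone trace with the test's parameters (illustrative only):

// Per-step size trace for in = 2, k = 3, stride = 2, padding = 1,
// output_padding = 1, dilation = 1 (the parameters of the test above).
#include <cassert>
#include <cstdint>

int main() {
  const int64_t in = 2, k = 3, stride = 2, pad = 1, outPad = 1, dil = 1;

  int64_t upsampled = (in - 1) * stride + 1;   // zeros inserted: 3
  int64_t before = dil * (k - 1) - pad;        // leading pad: 1
  int64_t after = before + outPad;             // trailing pad: 2
  int64_t padded = upsampled + before + after; // 6
  int64_t convOut = padded - dil * (k - 1);    // stride-1 conv3d output: 4

  assert(convOut == 4); // matches !torch.vtensor<[1,1,4,4,4],f32>
  return 0;
}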