diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 0c152903b836..ef826b78ec94 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2380,8 +2380,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
-  SmallVector<int64_t> padding(
-      {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]});
+
+  int64_t padH = padding_2d[0];
+  // When padding is 'valid', Torch produces a one-element padding list.
+  int64_t padW = (padding_2d.size() > 1) ? padding_2d[1] : padding_2d[0];
+
+  SmallVector<int64_t> padding({padH, padH, padW, padW});
 
   SmallVector<int64_t> dilation;
   if (!matchPattern(adaptor.getDilation(), m_TorchListOfConstantInts(dilation)))
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index d2e46b86fe00..bd33f1559b98 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -3891,6 +3891,52 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.convolution$valid_padding(
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
+// CHECK: %[[INPUT_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[INPUT_TENSOR]] : !torch.vtensor<[1,1,5,5],f32> -> tensor<1x1x5x5xf32>
+// CHECK: %[[WEIGHT_CONST:.*]] = "tosa.const"() <{values = dense<-7.486820e-03> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
+// CHECK: %[[BIAS_CONST:.*]] = "tosa.const"() <{values = dense<0.536443591> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[STRIDE_H:.*]] = torch.constant.int 1
+// CHECK: %[[STRIDE_W:.*]] = torch.constant.int 1
+// CHECK: %[[STRIDES_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE_H]], %[[STRIDE_W]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[PADDING_VAL:.*]] = torch.constant.int 0
+// CHECK: %[[PADDING_LIST:.*]] = torch.prim.ListConstruct %[[PADDING_VAL]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[DILATION_H:.*]] = torch.constant.int 1
+// CHECK: %[[DILATION_W:.*]] = torch.constant.int 1
+// CHECK: %[[DILATIONS_LIST:.*]] = torch.prim.ListConstruct %[[DILATION_H]], %[[DILATION_W]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[TRANSPOSED_VAL:.*]] = torch.constant.bool false
+// CHECK: %[[OUTPUT_PADDING_VAL:.*]] = torch.constant.int 0
+// CHECK: %[[OUTPUT_PADDING_LIST:.*]] = torch.prim.ListConstruct %[[OUTPUT_PADDING_VAL]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[GROUPS_VAL:.*]] = torch.constant.int 1
+// CHECK: %[[WEIGHT_TRANSPOSED:.*]] = tosa.transpose %[[WEIGHT_CONST]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+// CHECK: %[[INPUT_TRANSPOSED:.*]] = tosa.transpose %[[INPUT_BUILTIN]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x1x5x5xf32>) -> tensor<1x5x5x1xf32>
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV_RESULT_TOSA:.*]] = tosa.conv2d %[[INPUT_TRANSPOSED]], %[[WEIGHT_TRANSPOSED]], %[[BIAS_CONST]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x5x5x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32>
+// CHECK: %[[OUTPUT_TRANSPOSED:.*]] = tosa.transpose %[[CONV_RESULT_TOSA]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x5x5x1xf32>) -> tensor<1x1x5x5xf32>
+// CHECK: %[[OUTPUT_TENSOR:.*]] = torch_c.from_builtin_tensor %[[OUTPUT_TRANSPOSED]] : tensor<1x1x5x5xf32> -> !torch.vtensor<[1,1,5,5],f32>
+// CHECK: return %[[OUTPUT_TENSOR]] : !torch.vtensor<[1,1,5,5],f32>
+func.func @torch.aten.convolution$valid_padding(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
+  %0 = torch.vtensor.literal(dense<-7.486820e-03> : tensor<1x1x1x1xf32>) : !torch.vtensor<[1,1,1,1],f32>
+  %1 = torch.vtensor.literal(dense<0.536443591> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+  %int1 = torch.constant.int 1
+  %int1_0 = torch.constant.int 1
+  %2 = torch.prim.ListConstruct %int1, %int1_0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int0 = torch.constant.int 0
+  %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
+  %int1_1 = torch.constant.int 1
+  %int1_2 = torch.constant.int 1
+  %4 = torch.prim.ListConstruct %int1_1, %int1_2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %false = torch.constant.bool false
+  %int0_3 = torch.constant.int 0
+  %5 = torch.prim.ListConstruct %int0_3 : (!torch.int) -> !torch.list<int>
+  %int1_4 = torch.constant.int 1
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %false, %5, %int1_4 : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
+  return %6 : !torch.vtensor<[1,1,5,5],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
 // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1,56,56],f32>) -> !torch.vtensor<[1,1,27,27],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,56,56],f32> -> tensor<1x1x56x56xf32>
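
Note (not part of the patch): the padding normalization added above can be exercised in isolation. Below is a minimal standalone C++ sketch; expandToTosaPadding is a hypothetical helper name, and padding_2d stands in for the padding list parsed from the aten.convolution op.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone mirror of the logic added in
// ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite above.
static std::vector<int64_t>
expandToTosaPadding(const std::vector<int64_t> &padding_2d) {
  int64_t padH = padding_2d[0];
  // padding='valid' can reach the lowering as a one-element list, so fall
  // back to the height value when no explicit width value is present.
  int64_t padW = (padding_2d.size() > 1) ? padding_2d[1] : padding_2d[0];
  // TOSA expects {top, bottom, left, right}.
  return {padH, padH, padW, padW};
}

int main() {
  // 'valid' padding: a one-element list expands to all-zero TOSA padding
  // without reading past the end of the list.
  assert(expandToTosaPadding({0}) == (std::vector<int64_t>{0, 0, 0, 0}));
  // Explicit {height, width} padding expands pairwise, as before the patch.
  assert(expandToTosaPadding({1, 2}) == (std::vector<int64_t>{1, 1, 2, 2}));
  return 0;
}

With a one-element list, the single value is reused for the width dimension, avoiding the out-of-bounds read of padding_2d[1], while the two-element case keeps the same top=bottom/left=right expansion as the removed code.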