Improve lowering of backward convolution to linalg #4355

@zjgarvey


I'm making this issue to track some steps for improving the backward conv lowering to align with some of the custom IR generation we have been experimenting with in iree-turbine (BOO).

1. Write a direct lowering from torch-to-linalg for torch.aten.convolution_backward

This pattern should generate linalg.generic ops directly. The primary reason is that we can assign reduction/parallel dims directly, instead of introducing several transposes, expand/collapse ops, etc., all to work around the fact that torch.aten.convolution is always channels-first, while backward convs (both weight backward and input backward) are never in NCHW format. This is particularly important for grouped backward convolutions, which require reassociating the groups from the forward conv's channel dims to whatever dim is being treated as the "channel" reduction dim in the backward conv.

Essentially, the pattern should:

  • Identify which inputs require gradients from the backward mask.
  • Expand shapes (to split out the groups, if present). If G = groups, the expanded grad_output should be of shape N x G x F/G x <output_spatial_dims>.
    - If grouped and weight backward is required, also expand the input tensor to N x G x C/G x <input_spatial_dims>.
    - If grouped and input backward is required, also expand the weight tensor to G x F/G x C/G x <weight_spatial_dims>.
  • For the weight backward calculation: This is essentially a regular convolution between the padded input and grad_output (acting as the weight), except that the "channel" reduction happens over the batch dim N, and the output of this modified forward conv is, unfortunately, not necessarily the same shape as the weight when stride != 1. Writing this as a generic op allows initializing an output of the correct shape (the same shape as the weight tensor); otherwise, we would need to manually extract a subview from the output of a convolution to get the correct shape. (See the numerical sketch after this list.)
  • For the input backward calculation: The non-unit weight spatial dims must be flipped. If stride != 1, we currently insert grad_output into a zero-initialized tensor of shape N x G x F/G x <D_in + 2*padding - dilation*(K-1)>, then do a convolution between modified_grad_output and weight_flip, where the strides are 1, the padding is dilation*(K-1) - padding, the dilation is the dilation of the forward conv, and the reduction "channel" dim is F/G.
  • For bias backward calculation: Do a sum reduction over every dim besides F in grad_output to get a rank-1 tensor of size F.
  • For weight/input backward, collapse the group dim back down, if present.
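
To make the recipe above concrete, here is a small numerical sketch in PyTorch (not the proposed torch-to-linalg pattern itself) that checks the bias, input, and weight backward steps against autograd for a grouped, strided, padded 2-D convolution. The sizes, variable names, and the im2col shortcut used for the weight gradient are illustrative assumptions:

import torch
import torch.nn.functional as F

# Illustrative forward conv configuration: grouped, strided, padded, 2-D, NCHW.
N, C, H, W = 2, 6, 10, 10
G, Fout, K = 3, 9, 3
s, p, d = 2, 1, 1

x = torch.randn(N, C, H, W, requires_grad=True)
w = torch.randn(Fout, C // G, K, K, requires_grad=True)
b = torch.randn(Fout, requires_grad=True)

y = F.conv2d(x, w, b, stride=s, padding=p, dilation=d, groups=G)
grad_out = torch.randn_like(y)
Ho, Wo = y.shape[2], y.shape[3]
gx_ref, gw_ref, gb_ref = torch.autograd.grad(y, (x, w, b), grad_out)

# Bias backward: sum-reduce grad_output over every dim besides F.
gb = grad_out.sum(dim=(0, 2, 3))

# Input backward: spread grad_output by the stride into a zero tensor whose spatial
# size is D_in + 2*padding - dilation*(K-1), then convolve it with the spatially
# flipped weight (channels transposed within each group) using stride 1,
# padding = dilation*(K-1) - padding, and the forward dilation. The group expansion
# described in the list is handled here by conv2d's groups argument; the "channel"
# reduction is over F/G within each group.
Hm, Wm = H + 2 * p - d * (K - 1), W + 2 * p - d * (K - 1)
mod = torch.zeros(N, Fout, Hm, Wm)
mod[:, :, : (Ho - 1) * s + 1 : s, : (Wo - 1) * s + 1 : s] = grad_out
w_flip = (
    w.detach()
    .view(G, Fout // G, C // G, K, K)
    .transpose(1, 2)      # swap F/G and C/G within each group
    .reshape(C, Fout // G, K, K)
    .flip(-1, -2)         # flip the non-unit spatial dims
)
gx = F.conv2d(mod, w_flip, stride=1, padding=d * (K - 1) - p, dilation=d, groups=G)

# Weight backward (written with im2col for brevity rather than as a conv reducing over N):
# cols[n, g, c, kh, kw, oh, ow] = x_padded[n, g*C/G + c, oh*s + kh*d, ow*s + kw*d],
# so reducing over n, oh, ow yields the weight gradient directly in the weight's shape.
cols = F.unfold(x.detach(), kernel_size=K, dilation=d, padding=p, stride=s)
cols = cols.view(N, G, C // G, K, K, Ho, Wo)
gw = torch.einsum("ngfhw,ngcxyhw->gfcxy", grad_out.view(N, G, Fout // G, Ho, Wo), cols)
gw = gw.reshape(Fout, C // G, K, K)

assert torch.allclose(gb, gb_ref, atol=1e-4)
assert torch.allclose(gx, gx_ref, atol=1e-4)
assert torch.allclose(gw, gw_ref, atol=1e-4)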

2. Don't decompose torch.aten.convolution_backward in relevant pass pipelines

Namely, https://github.com/iree-org/iree/blob/06b8d14e7f66a95732b77b130d69336f42005f8a/compiler/plugins/input/Torch/InputConversion/Passes.h#L22-L28

But also, for local e2e testing in torch-mlir, the backend-legal op list:

BACKEND_LEGAL_OPS = {
    OutputType.LINALG_ON_TENSORS: [
        "aten.adaptive_max_pool2d",
    ],
}
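
For example, a hypothetical local edit (not something already in the tree): marking the op backend-legal keeps the decompositions from expanding it, so the new direct lowering actually gets exercised by the e2e tests:

BACKEND_LEGAL_OPS = {
    OutputType.LINALG_ON_TENSORS: [
        "aten.adaptive_max_pool2d",
        # hypothetical addition: keep the op from being decomposed during e2e testing
        "aten.convolution_backward",
    ],
}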

3. Handle transposes?

This may need to be managed by the lower-level compiler, but BOO currently detects the channels-last format and works it directly into the IR generation for the generic ops.

The primary concern after implementing the first two steps is that fusing the transposes from a pattern like:
permute(NHWC -> NCHW) -> conv_backward -> permute(NCHW -> NHWC)

is a bit difficult, primarily for backward data when stride != 1 or groups != 1. E.g., for backward data, having the above pattern in torch-mlir (in a channels-last func op produced through, e.g., Fusili/BOO) might generate the following:

Grad output preprocessing:
permute(NHWF -> NFHW) -> expand(NFHW -> NGF'HW) -> insert_slice(spread by stride along H, W) -> modified_grad_output
Weight preprocessing:
permute(FC'KhKw -> FKhKwC') -> flip(Kh, Kw) -> weight_flipped
Conv + result:
(modified_grad_output, weight_flipped) -> generic_conv_like(output shape = NGC'HW) -> collapse(NGC'HW -> NCHW) -> permute(NCHW -> NHWC)

Meanwhile, BOO currently generates the following, where the channels-last format is present throughout the IR and there are no explicit transposes. All of the affine maps, iterator_types, etc. are customized to do the reduction over the correct dim without introducing transposes directly.

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d6, d2 + d7, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d5, d6, d7, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
module @module {
  util.func public @conv_2d_bfloat16_input_backward_16x48x32x288_nhwc_288x3x3x96_fhwc_nhwf_2x2s_1x1p_1x1d_3g$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant 0.000000e+00 : bf16
    %c2 = arith.constant 2 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<16x24x16x288xbf16>
    %1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<288x3x3x96xbf16>
    // Expand groups
    %expanded = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [16, 24, 16, 3, 96] : tensor<16x24x16x288xbf16> into tensor<16x24x16x3x96xbf16>
    %expanded_1 = tensor.expand_shape %1 [[0, 1], [2], [3], [4]] output_shape [3, 96, 3, 3, 96] : tensor<288x3x3x96xbf16> into tensor<3x96x3x3x96xbf16>
    // flip weight tensor
    %2 = tensor.empty() : tensor<3x96x3x3x96xbf16>
    %3 = linalg.fill ins(%cst : bf16) outs(%2 : tensor<3x96x3x3x96xbf16>) -> tensor<3x96x3x3x96xbf16>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1 : tensor<3x96x3x3x96xbf16>) outs(%3 : tensor<3x96x3x3x96xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      %14 = linalg.index 0 : index
      %15 = linalg.index 1 : index
      %16 = linalg.index 2 : index
      %17 = linalg.index 3 : index
      %18 = linalg.index 4 : index
      %19 = arith.subi %c2, %16 : index
      %20 = arith.subi %c2, %17 : index
      %extracted = tensor.extract %expanded_1[%14, %15, %19, %20, %18] : tensor<3x96x3x3x96xbf16>
      linalg.yield %extracted : bf16
    } -> tensor<3x96x3x3x96xbf16>
    // modify grad_output tensor
    %5 = tensor.empty() : tensor<16x50x34x3x96xbf16>
    %6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<16x50x34x3x96xbf16>) -> tensor<16x50x34x3x96xbf16>
    %inserted_slice = tensor.insert_slice %expanded into %6[0, 1, 1, 0, 0] [16, 24, 16, 3, 96] [1, 2, 2, 1, 1] : tensor<16x24x16x3x96xbf16> into tensor<16x50x34x3x96xbf16>
    // convolution with "channel" reduction over "F/G" dim, which is "d5" in the affine maps.
    %7 = tensor.empty() : tensor<16x48x32x3x96xf32>
    %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<16x48x32x3x96xf32>) -> tensor<16x48x32x3x96xf32>
    %9 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %4 : tensor<16x50x34x3x96xbf16>, tensor<3x96x3x3x96xbf16>) outs(%8 : tensor<16x48x32x3x96xf32>) {
    ^bb0(%in: bf16, %in_2: bf16, %out: f32):
      %14 = arith.extf %in : bf16 to f32
      %15 = arith.extf %in_2 : bf16 to f32
      %16 = arith.mulf %14, %15 : f32
      %17 = arith.addf %out, %16 : f32
      linalg.yield %17 : f32
    } -> tensor<16x48x32x3x96xf32>
    %10 = tensor.empty() : tensor<16x48x32x3x96xbf16>
    %11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<16x48x32x3x96xf32>) outs(%10 : tensor<16x48x32x3x96xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %14 = arith.truncf %in : f32 to bf16
      linalg.yield %14 : bf16
    } -> tensor<16x48x32x3x96xbf16>
    // collapse groups into "C/G" dim.
    %collapsed = tensor.collapse_shape %11 [[0], [1], [2], [3, 4]] : tensor<16x48x32x3x96xbf16> into tensor<16x48x32x288xbf16>
    %12 = hal.tensor.barrier join(%collapsed : tensor<16x48x32x288xbf16>) => %arg4 : !hal.fence
    %13 = hal.tensor.export %12 : tensor<16x48x32x288xbf16> -> !hal.buffer_view
    util.return %13 : !hal.buffer_view
  }
  util.func public @conv_2d_bfloat16_input_backward_16x48x32x288_nhwc_288x3x3x96_fhwc_nhwf_2x2s_1x1p_1x1d_3g(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @conv_2d_bfloat16_input_backward_16x48x32x288_nhwc_288x3x3x96_fhwc_nhwf_2x2s_1x1p_1x1d_3g$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1 : !hal.buffer_view
  }
}

I'm not sure how viable it would be to have IREE handle sinking transposes past expand_shape and insert_slice. We could perhaps consider adding a special convolution op in torch-mlir which can explicitly record layouts, and a pattern to convert permute(NHWC -> NCHW) -> conv/backward_conv -> permute(NCHW -> NHWC) to special_conv/special_backward_conv(input_layout=NCHW, ...). This way a lowering for convolution backward can handle the tensor layouts directly.
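
As a rough illustration of where that permute sandwich comes from on the torch.export / FX-importer path (the module, shapes, and conv parameters below are illustrative assumptions, not BOO's actual signature generation):

import torch

class ChannelsLastConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(288, 96, 3, stride=2, padding=1, groups=3, bias=False)

    def forward(self, x):
        # x arrives logically NHWC; aten.convolution is channels-first, so the graph
        # ends up as permute(NHWC -> NCHW) -> convolution -> permute(NCHW -> NHWC).
        y = self.conv(x.permute(0, 3, 1, 2))
        return y.permute(0, 2, 3, 1)

ep = torch.export.export(ChannelsLastConv(), (torch.randn(16, 48, 32, 288),))
print(ep.graph)  # shows the permute -> convolution -> permute sandwich

A rewrite along the lines proposed above would match this sandwich (and its backward counterpart) and record the layouts on the conv op itself.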
