Description
I'm making this issue to track some steps for improving the backward conv lowering to align with some of the custom IR generation we have been experimenting with in iree-turbine (BOO).
1. Write a direct lowering from torch-to-linalg for torch.aten.convolution_backward
This pattern should generate linalg.generic ops directly. The primary reason is that we can assign reduction/parallel dims directly, instead of introducing several transposes, expand/collapse ops, etc., all to work around the fact that torch.aten.convolution is always channels-first, while backward convs (both weight backward and input backward) are never in NCHW format. This is particularly important for grouped backward convolutions, which require reassociating the groups from the forward conv's channel dims to whichever dim is treated as the "channel" reduction dim in the backward conv.
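For reference, the op in question can be exercised directly from Python. Here is a minimal sketch (assuming the standard aten::convolution_backward schema, with shapes chosen to match the example IR later in this issue) showing how the output mask selects which of the three gradients get computed:

```python
import torch

# Forward conv matching the grouped, strided example further down
# (NCHW here; the layout question is discussed in step 3).
N, C, Fch, H, W, K, groups = 16, 96, 288, 48, 32, 3, 3
x = torch.randn(N, C, H, W)
w = torch.randn(Fch, C // groups, K, K)
y = torch.nn.functional.conv2d(x, w, stride=2, padding=1, dilation=1, groups=groups)
grad_output = torch.randn_like(y)

grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_output, x, w,
    [Fch],              # bias_sizes
    [2, 2],             # stride
    [1, 1],             # padding
    [1, 1],             # dilation
    False,              # transposed
    [0, 0],             # output_padding
    groups,
    [True, True, True], # mask: which of (grad_input, grad_weight, grad_bias) to compute
)
print(grad_input.shape, grad_weight.shape, grad_bias.shape)
```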
Essentially, the pattern should do:
- Identify which inputs require gradients from `backward_mask`.
- Expand shape (to expand groups if present). If `G = groups`, then the expanded `grad_output` should be of shape `N x G x F/G x <output_spatial_dims>`.
- If grouped and weight backward is required, also expand the `input` tensor to `N x G x C/G x <input_spatial_dims>`.
- If grouped and input backward is required, also expand the `weight` tensor to `G x F/G x C/G x <weight_spatial_dims>`.
- For the weight backward calculation: this is essentially a regular convolution between the padded `input` and `grad_output` (as the weight), except the "channel" reduction happens over the batch dim `N`, and the output of this modified forward conv is, unfortunately, not necessarily the same shape as the weight when stride != 1. Writing this as a generic op would allow initializing the correct output shape (it should match the shape of the weight tensor); otherwise we would need to manually extract a subview from the output of a convolution to get the correct shape.
- For the input backward calculation (a small numerical sketch of this construction follows the list): the non-unit weight spatial dims must be flipped. If stride != 1, we currently insert `grad_output` into a zero-init tensor of shape `N x G x F/G x <D_in + 2*padding - dilation*(K-1)>`, then do a convolution between `modified_grad_output` and `weight_flip` where the strides are 1, the padding is `dilation*(K-1) - padding`, the dilation is the dilation of the forward conv, and the reduction "channel" dim is `F/G`.
- For the bias backward calculation: do a sum reduction over every dim besides `F` in `grad_output` to get a rank-1 tensor of size `F`.
- For weight/input backward, collapse the group dim back down if present.
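As a sanity check of the input-backward recipe above, here is a small PyTorch sketch (non-grouped, 2-D, and assuming the forward conv tiles the input exactly and `dilation*(K-1) >= padding`) that flips the weight, spreads `grad_output` by the stride into the zero-init tensor, and runs the stride-1 convolution with the adjusted padding, comparing the result against autograd:

```python
import torch
import torch.nn.functional as F

N, C, Fch, H, W, K = 2, 4, 6, 9, 9, 3
stride, padding, dilation = 2, 1, 1

x = torch.randn(N, C, H, W, requires_grad=True)
w = torch.randn(Fch, C, K, K)
y = F.conv2d(x, w, stride=stride, padding=padding, dilation=dilation)
grad_output = torch.randn_like(y)
(ref_grad_input,) = torch.autograd.grad(y, x, grad_output)

# 1. Flip the non-unit weight spatial dims.
w_flip = w.flip(dims=(2, 3))

# 2. Insert grad_output into a zero-init tensor of spatial size
#    D_in + 2*padding - dilation*(K-1), spread by the forward stride.
Hs = H + 2 * padding - dilation * (K - 1)
Ws = W + 2 * padding - dilation * (K - 1)
modified_grad_output = grad_output.new_zeros(N, Fch, Hs, Ws)
modified_grad_output[:, :, ::stride, ::stride] = grad_output

# 3. Stride-1 conv with padding dilation*(K-1) - padding and the forward dilation;
#    the reduction is over F, so the flipped weight is viewed as C x F x Kh x Kw.
new_pad = dilation * (K - 1) - padding
grad_input = F.conv2d(modified_grad_output, w_flip.transpose(0, 1),
                      stride=1, padding=new_pad, dilation=dilation)

torch.testing.assert_close(grad_input, ref_grad_input)
```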
2. Don't decompose torch.aten.convolution_backward in relevant pass pipelines
But also, this BACKEND_LEGAL_OPS list in torch-mlir/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py (lines 23 to 27 at 8b77de9) is relevant:

```python
BACKEND_LEGAL_OPS = {
    OutputType.LINALG_ON_TENSORS: [
        "aten.adaptive_max_pool2d",
    ],
}
```
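Presumably the concrete change for this step would look something like the following (a hedged sketch; only the added entry is new, and the op string "aten.convolution_backward" is an assumption based on the naming of the existing entry):

```python
BACKEND_LEGAL_OPS = {
    OutputType.LINALG_ON_TENSORS: [
        "aten.adaptive_max_pool2d",
        # Keep the op intact so the direct lowering from step 1 handles it.
        "aten.convolution_backward",
    ],
}
```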
3. Handle transposes?
This may need to be managed by the lower-level compiler, but BOO is currently detecting channels-last format, and directly working this into the IR generation for the generic ops.
The primary concern after implementing the first two steps is that fusing the transposes from a pattern like:
permute(NHWC -> NCHW) -> conv_backward -> permute(NCHW -> NHWC)
is a bit difficult, primarily for backward data when stride != 1 or groups != 1. E.g., for backward data, having the above pattern in torch-mlir might generate a channels-last func op (e.g., through Fusili/BOO) which does:
Grad output preprocessing:
permute(NHWF -> NFHW) -> expand(NFHW -> NGF'HW) -> insert_slice(spread by stride along H,W) -> modified_grad_output
Weight pre-processing:
permute(FC'KhKw -> FKhKwC') -> flip(Kh, Kw) -> weight_flipped
Conv + result:
(modified_grad_output, weight_flipped) -> generic_conv_like(output shape = NGC'HW) -> collapse (NGC'HW-> NCHW) -> permute(NCHW -> NHWC)
Meanwhile, BOO currently generates the following IR, where the channels-last format is present throughout and there are no explicit transposes. All of the affine maps, iterator_types, etc. are customized to do the reduction over the correct dim without introducing transposes directly.
```mlir
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d6, d2 + d7, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d5, d6, d7, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
module @module {
util.func public @conv_2d_bfloat16_input_backward_16x48x32x288_nhwc_288x3x3x96_fhwc_nhwf_2x2s_1x1p_1x1d_3g$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : bf16
%c2 = arith.constant 2 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<16x24x16x288xbf16>
%1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<288x3x3x96xbf16>
// Expand groups
%expanded = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [16, 24, 16, 3, 96] : tensor<16x24x16x288xbf16> into tensor<16x24x16x3x96xbf16>
%expanded_1 = tensor.expand_shape %1 [[0, 1], [2], [3], [4]] output_shape [3, 96, 3, 3, 96] : tensor<288x3x3x96xbf16> into tensor<3x96x3x3x96xbf16>
// flip weight tensor
%2 = tensor.empty() : tensor<3x96x3x3x96xbf16>
%3 = linalg.fill ins(%cst : bf16) outs(%2 : tensor<3x96x3x3x96xbf16>) -> tensor<3x96x3x3x96xbf16>
%4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1 : tensor<3x96x3x3x96xbf16>) outs(%3 : tensor<3x96x3x3x96xbf16>) {
^bb0(%in: bf16, %out: bf16):
%14 = linalg.index 0 : index
%15 = linalg.index 1 : index
%16 = linalg.index 2 : index
%17 = linalg.index 3 : index
%18 = linalg.index 4 : index
%19 = arith.subi %c2, %16 : index
%20 = arith.subi %c2, %17 : index
%extracted = tensor.extract %expanded_1[%14, %15, %19, %20, %18] : tensor<3x96x3x3x96xbf16>
linalg.yield %extracted : bf16
} -> tensor<3x96x3x3x96xbf16>
// modify grad_output tensor
%5 = tensor.empty() : tensor<16x50x34x3x96xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<16x50x34x3x96xbf16>) -> tensor<16x50x34x3x96xbf16>
%inserted_slice = tensor.insert_slice %expanded into %6[0, 1, 1, 0, 0] [16, 24, 16, 3, 96] [1, 2, 2, 1, 1] : tensor<16x24x16x3x96xbf16> into tensor<16x50x34x3x96xbf16>
// convolution with "channel" reduction over "F/G" dim, which is "d5" in the affine maps.
%7 = tensor.empty() : tensor<16x48x32x3x96xf32>
%8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<16x48x32x3x96xf32>) -> tensor<16x48x32x3x96xf32>
%9 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %4 : tensor<16x50x34x3x96xbf16>, tensor<3x96x3x3x96xbf16>) outs(%8 : tensor<16x48x32x3x96xf32>) {
^bb0(%in: bf16, %in_2: bf16, %out: f32):
%14 = arith.extf %in : bf16 to f32
%15 = arith.extf %in_2 : bf16 to f32
%16 = arith.mulf %14, %15 : f32
%17 = arith.addf %out, %16 : f32
linalg.yield %17 : f32
} -> tensor<16x48x32x3x96xf32>
%10 = tensor.empty() : tensor<16x48x32x3x96xbf16>
%11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<16x48x32x3x96xf32>) outs(%10 : tensor<16x48x32x3x96xbf16>) {
^bb0(%in: f32, %out: bf16):
%14 = arith.truncf %in : f32 to bf16
linalg.yield %14 : bf16
} -> tensor<16x48x32x3x96xbf16>
// collapse groups into "C/G" dim.
%collapsed = tensor.collapse_shape %11 [[0], [1], [2], [3, 4]] : tensor<16x48x32x3x96xbf16> into tensor<16x48x32x288xbf16>
%12 = hal.tensor.barrier join(%collapsed : tensor<16x48x32x288xbf16>) => %arg4 : !hal.fence
%13 = hal.tensor.export %12 : tensor<16x48x32x288xbf16> -> !hal.buffer_view
util.return %13 : !hal.buffer_view
}
util.func public @conv_2d_bfloat16_input_backward_16x48x32x288_nhwc_288x3x3x96_fhwc_nhwf_2x2s_1x1p_1x1d_3g(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%0 = util.null : !hal.fence
%c-1_i32 = arith.constant -1 : i32
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
%1 = util.call @conv_2d_bfloat16_input_backward_16x48x32x288_nhwc_288x3x3x96_fhwc_nhwf_2x2s_1x1p_1x1d_3g$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.return %1 : !hal.buffer_view
}
}
```

I'm not sure how viable it would be to have IREE handle sinking transposes past expand_shape and insert_slice. We could perhaps consider adding a special convolution op in torch-mlir which can explicitly record layouts, and a pattern to convert permute(NHWC -> NCHW) -> conv/backward_conv -> permute(NCHW -> NHWC) to special_conv/special_backward_conv(input_layout=NCHW, ...). This way, a lowering for convolution backward can handle the tensor layouts directly.