Skip to content

Slow attention kernel on navi48 #2096

@pfultz2

Description

@pfultz2

This attention kernel runs faster unfused than fused in MIGraphX:

# Reproducer: an attention subgraph dot(A, B) -> softmax(axis 3) -> dot(., C)
# that runs slower when fused into one attention kernel than unfused.
#
# Input x0 is a [1, 9216, 1536] half tensor. Its last axis is split into three
# 512-wide slices:
#   [0, 512)     -> left operand of the first dot (scaled by x_1)
#   [512, 1024)  -> right operand of the first dot, transposed to [.., 512, 9216]
#                   (scaled by x_0)
#   [1024, 1536) -> right operand of the second dot
import migraphx

p = migraphx.program()
m = p.get_main_module()
# Both scale literals are 0.210205. Since 0.210205**2 ~= 1/sqrt(512), scaling
# each first-dot operand by it gives their product a combined 1/sqrt(512)
# factor (512 is the slice width the first dot contracts over).
x_0 = m.add_literal(migraphx.create_argument(migraphx.shape(type="half_type", lens=[1]), [0.210205]))
x_1 = m.add_literal(migraphx.create_argument(migraphx.shape(type="half_type", lens=[1]), [0.210205]))
p_x0 = m.add_parameter("x0", migraphx.shape(type="half_type", lens=[1, 9216, 1536]))
# Left operand of the first dot: slice [0, 512) -> [1, 1, 9216, 512], scaled.
x_3 = m.add_instruction(migraphx.op("slice", axes=[2], starts=[0], ends=[512]), [p_x0]) # migraphx.shape(type="half_type", lens=[1, 9216, 512], strides=[14155776, 1536, 1])
x_4 = m.add_instruction(migraphx.op("unsqueeze", axes=[2]), [x_3]) # migraphx.shape(type="half_type", lens=[1, 9216, 1, 512], strides=[14155776, 1536, 512, 1])
x_5 = m.add_instruction(migraphx.op("transpose", permutation=[0,2,1,3]), [x_4]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 512], strides=[14155776, 512, 1536, 1])
x_6 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,9216,512]), [x_1]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 512], strides=[0, 0, 0, 0])
x_7 = m.add_instruction(migraphx.op("mul"), [x_5, x_6]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 512])
# Right operand of the first dot: slice [512, 1024) -> [1, 1, 512, 9216], scaled.
x_8 = m.add_instruction(migraphx.op("slice", axes=[2], starts=[512], ends=[1024]), [p_x0]) # migraphx.shape(type="half_type", lens=[1, 9216, 512], strides=[14155776, 1536, 1])
x_9 = m.add_instruction(migraphx.op("unsqueeze", axes=[2]), [x_8]) # migraphx.shape(type="half_type", lens=[1, 9216, 1, 512], strides=[14155776, 1536, 512, 1])
x_10 = m.add_instruction(migraphx.op("transpose", permutation=[0,2,3,1]), [x_9]) # migraphx.shape(type="half_type", lens=[1, 1, 512, 9216], strides=[14155776, 512, 1, 1536])
x_11 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,512,9216]), [x_0]) # migraphx.shape(type="half_type", lens=[1, 1, 512, 9216], strides=[0, 0, 0, 0])
x_12 = m.add_instruction(migraphx.op("mul"), [x_10, x_11]) # migraphx.shape(type="half_type", lens=[1, 1, 512, 9216], strides=[4718592, 512, 1, 512])
# Right operand of the second dot: slice [1024, 1536) -> [1, 1, 9216, 512].
x_13 = m.add_instruction(migraphx.op("slice", axes=[2], starts=[1024], ends=[1536]), [p_x0]) # migraphx.shape(type="half_type", lens=[1, 9216, 512], strides=[14155776, 1536, 1])
x_14 = m.add_instruction(migraphx.op("unsqueeze", axes=[2]), [x_13]) # migraphx.shape(type="half_type", lens=[1, 9216, 1, 512], strides=[14155776, 1536, 512, 1])
x_15 = m.add_instruction(migraphx.op("transpose", permutation=[0,2,1,3]), [x_14]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 512], strides=[14155776, 512, 1536, 1])
# First GEMM: [9216, 512] x [512, 9216] -> [9216, 9216].
x_16 = m.add_instruction(migraphx.op("dot"), [x_7, x_12]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 9216])
# Softmax over axis 3, computed in float (target_type=2): subtract the row
# max, exponentiate, divide by the row sum.
x_17 = m.add_instruction(migraphx.op("convert", target_type=2), [x_16]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 9216])
x_18 = m.add_instruction(migraphx.op("reduce_max", axes=[3]), [x_17]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 1])
x_19 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,9216,9216]), [x_18]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 9216], strides=[9216, 9216, 1, 0])
x_20 = m.add_instruction(migraphx.op("sub"), [x_17, x_19]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 9216])
x_21 = m.add_instruction(migraphx.op("exp"), [x_20]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 9216])
x_22 = m.add_instruction(migraphx.op("reduce_sum", axes=[3]), [x_21]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 1])
x_23 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,9216,9216]), [x_22]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 9216], strides=[9216, 9216, 1, 0])
x_24 = m.add_instruction(migraphx.op("div"), [x_21, x_23]) # migraphx.shape(type="float_type", lens=[1, 1, 9216, 9216])
# Back to half (target_type=1, the same value the generated MLIR shows for this
# convert); spelled out explicitly instead of relying on the op default.
x_25 = m.add_instruction(migraphx.op("convert", target_type=1), [x_24]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 9216])
# Second GEMM: [9216, 9216] x [9216, 512] -> [9216, 512].
x_26 = m.add_instruction(migraphx.op("dot"), [x_25, x_15]) # migraphx.shape(type="half_type", lens=[1, 1, 9216, 512])
m.add_return([x_26])

With attention fusion enabled it runs in 5.10197 ms; without attention fusion (unfused) it runs in 2.80401 ms. This is the MLIR module:

// Fused attention module for the program above: three 512-wide slices of %arg0
// feed dot(%6, %11) -> softmax over axis 3 (computed in f32) -> dot(., %14).
// Reported timings: 5.10 ms with this fused kernel vs 2.80 ms unfused.
module {
  // Single entry point; attributes record the target (arch gfx1201, num_cu 32).
  func.func @mlir_slice_unsqueeze_transpose_mul_slice_unsqueeze_transpose_mul_slice_unsqueeze_transpose_dot_convert_reshape_reduce_max_reshape_sub_exp_reshape_reduce_sum_reshape_div_convert_dot(%arg0: !migraphx.shaped<1x9216x1536xf16, 14155776x1536x1>) -> !migraphx.shaped<1x1x9216x512xf16, 4718592x4718592x512x1> attributes {arch = "gfx1201", kernel = "mixr", num_cu = 32 : i64} {
    // Two identical f16 scale literals (0.210205); %0 scales %4, %1 scales %9.
    %0 = migraphx.literal(dense<2.102050e-01> : tensor<1xf16>) : <1xf16, 0>
    %1 = migraphx.literal(dense<2.102050e-01> : tensor<1xf16>) : <1xf16, 0>
    // Left operand of the first dot: channels [0, 512) viewed as 1x1x9216x512, scaled.
    %2 = migraphx.slice %arg0 {axes = [2], ends = [512], starts = [0]} : <1x9216x1536xf16, 14155776x1536x1> -> <1x9216x512xf16, 14155776x1536x1>
    %3 = migraphx.reshape %2 {dims = [1, 9216, 1, 512]} : <1x9216x512xf16, 14155776x1536x1> -> <1x9216x1x512xf16, 14155776x1536x512x1>
    %4 = migraphx.transpose %3 {permutation = [0, 2, 1, 3]} : <1x9216x1x512xf16, 14155776x1536x512x1> -> <1x1x9216x512xf16, 14155776x512x1536x1>
    %5 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 9216, 512]} : <1xf16, 0> -> <1x1x9216x512xf16, 0x0x0x0>
    %6 = migraphx.mul %4, %5 : <1x1x9216x512xf16, 14155776x512x1536x1>, <1x1x9216x512xf16, 0x0x0x0> -> <1x1x9216x512xf16, 4718592x512x512x1>
    // Right operand of the first dot: channels [512, 1024) transposed to 1x1x512x9216, scaled.
    %7 = migraphx.slice %arg0 {axes = [2], ends = [1024], starts = [512]} : <1x9216x1536xf16, 14155776x1536x1> -> <1x9216x512xf16, 14155776x1536x1>
    %8 = migraphx.reshape %7 {dims = [1, 9216, 1, 512]} : <1x9216x512xf16, 14155776x1536x1> -> <1x9216x1x512xf16, 14155776x1536x512x1>
    %9 = migraphx.transpose %8 {permutation = [0, 2, 3, 1]} : <1x9216x1x512xf16, 14155776x1536x512x1> -> <1x1x512x9216xf16, 14155776x512x1x1536>
    %10 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 1, 512, 9216]} : <1xf16, 0> -> <1x1x512x9216xf16, 0x0x0x0>
    %11 = migraphx.mul %9, %10 : <1x1x512x9216xf16, 14155776x512x1x1536>, <1x1x512x9216xf16, 0x0x0x0> -> <1x1x512x9216xf16, 4718592x512x1x512>
    // Right operand of the second dot: channels [1024, 1536) viewed as 1x1x9216x512.
    %12 = migraphx.slice %arg0 {axes = [2], ends = [1536], starts = [1024]} : <1x9216x1536xf16, 14155776x1536x1> -> <1x9216x512xf16, 14155776x1536x1>
    %13 = migraphx.reshape %12 {dims = [1, 9216, 1, 512]} : <1x9216x512xf16, 14155776x1536x1> -> <1x9216x1x512xf16, 14155776x1536x512x1>
    %14 = migraphx.transpose %13 {permutation = [0, 2, 1, 3]} : <1x9216x1x512xf16, 14155776x1536x512x1> -> <1x1x9216x512xf16, 14155776x512x1536x1>
    // First GEMM: 9216x512 * 512x9216 -> 9216x9216 (f16).
    %15 = migraphx.dot %6, %11 : <1x1x9216x512xf16, 4718592x512x512x1>, <1x1x512x9216xf16, 4718592x512x1x512> -> <1x1x9216x9216xf16, 84934656x84934656x9216x1>
    // Softmax along axis 3 in f32: convert up, subtract the row max, exp,
    // divide by the row sum, convert back to f16. The reshapes %17, %19, %23
    // and %25 have dims equal to their input shapes (identity reshapes).
    %16 = migraphx.convert %15 {target_type = 2 : i64} : <1x1x9216x9216xf16, 84934656x84934656x9216x1> to <1x1x9216x9216xf32, 84934656x84934656x9216x1>
    %17 = migraphx.reshape %16 {dims = [1, 1, 9216, 9216]} : <1x1x9216x9216xf32, 84934656x84934656x9216x1> -> <1x1x9216x9216xf32, 84934656x84934656x9216x1>
    %18 = migraphx.reduce_max %17 {axes = [3]} : <1x1x9216x9216xf32, 84934656x84934656x9216x1> -> <1x1x9216x1xf32, 9216x9216x1x1>
    %19 = migraphx.reshape %18 {dims = [1, 1, 9216, 1]} : <1x1x9216x1xf32, 9216x9216x1x1> -> <1x1x9216x1xf32, 9216x9216x1x1>
    %20 = migraphx.multibroadcast %19 {out_dyn_dims = [], out_lens = [1, 1, 9216, 9216]} : <1x1x9216x1xf32, 9216x9216x1x1> -> <1x1x9216x9216xf32, 9216x9216x1x0>
    %21 = migraphx.sub %16, %20 : <1x1x9216x9216xf32, 84934656x84934656x9216x1>, <1x1x9216x9216xf32, 9216x9216x1x0> -> <1x1x9216x9216xf32, 84934656x84934656x9216x1>
    %22 = migraphx.exp %21 : <1x1x9216x9216xf32, 84934656x84934656x9216x1> -> <1x1x9216x9216xf32, 84934656x84934656x9216x1>
    %23 = migraphx.reshape %22 {dims = [1, 1, 9216, 9216]} : <1x1x9216x9216xf32, 84934656x84934656x9216x1> -> <1x1x9216x9216xf32, 84934656x84934656x9216x1>
    %24 = migraphx.reduce_sum %23 {axes = [3]} : <1x1x9216x9216xf32, 84934656x84934656x9216x1> -> <1x1x9216x1xf32, 9216x9216x1x1>
    %25 = migraphx.reshape %24 {dims = [1, 1, 9216, 1]} : <1x1x9216x1xf32, 9216x9216x1x1> -> <1x1x9216x1xf32, 9216x9216x1x1>
    %26 = migraphx.multibroadcast %25 {out_dyn_dims = [], out_lens = [1, 1, 9216, 9216]} : <1x1x9216x1xf32, 9216x9216x1x1> -> <1x1x9216x9216xf32, 9216x9216x1x0>
    %27 = migraphx.div %22, %26 : <1x1x9216x9216xf32, 84934656x84934656x9216x1>, <1x1x9216x9216xf32, 9216x9216x1x0> -> <1x1x9216x9216xf32, 84934656x84934656x9216x1>
    %28 = migraphx.convert %27 {target_type = 1 : i64} : <1x1x9216x9216xf32, 84934656x84934656x9216x1> to <1x1x9216x9216xf16, 84934656x84934656x9216x1>
    // Second GEMM: 9216x9216 * 9216x512 -> 1x1x9216x512 result (f16).
    %29 = migraphx.dot %28, %14 : <1x1x9216x9216xf16, 84934656x84934656x9216x1>, <1x1x9216x512xf16, 14155776x512x1536x1> -> <1x1x9216x512xf16, 4718592x4718592x512x1>
    return %29 : !migraphx.shaped<1x1x9216x512xf16, 4718592x4718592x512x1>
  }
}

The winning attention config is attn:v2:64,64,64,8,32,32,4,1,1,2,1 (found with exhaustive tuning). The winning configs for the GEMMs are v3:128,256,4,64,64,8,1,1,2,1,1 and v3:128,128,8,32,128,8,1,1,2,1,1.

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions