Skip to content

Commit becb760

Browse files
committed
webgpu: add FormatTransform kernel and tests
- add the WebGPU FormatTransform kernel, headers, and WGSL template, supporting Plain <-> nChw4c/ABcd16a4b conversions
- register the FormatTransform schema in the internal NHWC domain with padding-aware shape inference
- hook the kernel into the WebGPU execution provider and add WebGPU tests covering both blocked formats, padding cases, and round trips
1 parent 977efe4 commit becb760

File tree

6 files changed

+729
-0
lines changed

6 files changed

+729
-0
lines changed

onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,84 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMES
163163
REGISTER_NHWC_SCHEMA_FROM_MSDOMAIN(fn, QLinearAveragePool, 1);
164164
REGISTER_NHWC_SCHEMA_FROM_MSDOMAIN(fn, QLinearConvTranspose, 1);
165165

166+
// FormatTransform operator for OneDNN blocked format support
167+
fn(std::move(::ONNX_NAMESPACE::OpSchema()
168+
.SetName("FormatTransform")
169+
.SetDomain(onnxruntime::kMSInternalNHWCDomain)
170+
.SinceVersion(1)
171+
.SetDoc("Transform tensor between plain (NCHW) and OneDNN blocked formats (nChw4c, ABcd16a4b).")
172+
.Attr("src_format", "Source format: Plain, nChw4c, or ABcd16a4b",
173+
ONNX_NAMESPACE::AttributeProto::STRING)
174+
.Attr("dst_format", "Destination format: Plain, nChw4c, or ABcd16a4b",
175+
ONNX_NAMESPACE::AttributeProto::STRING)
176+
.Input(0, "X", "Input tensor", "T")
177+
.Output(0, "Y", "Output tensor with transformed layout", "T")
178+
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
179+
"Constrain input and output types to floating-point tensors.")
180+
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
181+
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
182+
if (!ONNX_NAMESPACE::hasInputShape(ctx, 0)) {
183+
return;
184+
}
185+
186+
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
187+
if (input_shape.dim_size() != 4) {
188+
fail_shape_inference("FormatTransform requires 4D input tensor (NCHW)");
189+
}
190+
191+
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
192+
output_shape->clear_dim(); // Clear any existing dimensions before writing
193+
194+
// Get destination format attribute
195+
std::string dst_format;
196+
if (ctx.getAttribute("dst_format") != nullptr) {
197+
dst_format = ctx.getAttribute("dst_format")->s();
198+
}
199+
200+
// Calculate output shape with padding if needed
201+
if (dst_format == "nChw4c") {
202+
// Pad channels (dimension 1) to multiple of 4
203+
*output_shape->add_dim() = input_shape.dim(0);
204+
205+
if (input_shape.dim(1).has_dim_value()) {
206+
int64_t C = input_shape.dim(1).dim_value();
207+
int64_t padded_C = ((C + 3) / 4) * 4;
208+
output_shape->add_dim()->set_dim_value(padded_C);
209+
} else {
210+
// Dynamic channel dimension - can't compute padding statically
211+
output_shape->add_dim()->set_dim_param(input_shape.dim(1).dim_param());
212+
}
213+
214+
*output_shape->add_dim() = input_shape.dim(2);
215+
*output_shape->add_dim() = input_shape.dim(3);
216+
} else if (dst_format == "ABcd16a4b") {
217+
// Pad N (dimension 0) to multiple of 16 and C (dimension 1) to multiple of 4
218+
if (input_shape.dim(0).has_dim_value()) {
219+
int64_t N = input_shape.dim(0).dim_value();
220+
int64_t padded_N = ((N + 15) / 16) * 16;
221+
output_shape->add_dim()->set_dim_value(padded_N);
222+
} else {
223+
output_shape->add_dim()->set_dim_param(input_shape.dim(0).dim_param());
224+
}
225+
226+
if (input_shape.dim(1).has_dim_value()) {
227+
int64_t C = input_shape.dim(1).dim_value();
228+
int64_t padded_C = ((C + 3) / 4) * 4;
229+
output_shape->add_dim()->set_dim_value(padded_C);
230+
} else {
231+
output_shape->add_dim()->set_dim_param(input_shape.dim(1).dim_param());
232+
}
233+
234+
*output_shape->add_dim() = input_shape.dim(2);
235+
*output_shape->add_dim() = input_shape.dim(3);
236+
} else {
237+
// Plain or other formats: no padding needed
238+
for (int i = 0; i < input_shape.dim_size(); ++i) {
239+
*output_shape->add_dim() = input_shape.dim(i);
240+
}
241+
}
242+
})));
243+
166244
// not all schema are registered here. For part of layout insensitive ops
167245
// we will use onnx schema directly, for others, like fused-node/qdq-group
168246
// we may leverage internal schema or create on the fly.
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/vendor/intel/contrib/format_transform.h"
5+
#include "core/providers/webgpu/shader_helper.h"
6+
#include "core/providers/webgpu/webgpu_supported_types.h"
7+
#include "core/providers/webgpu/string_macros.h"
8+
#include "core/common/narrow.h"
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
namespace intel {
13+
14+
namespace {
15+
std::string GetFormatName(BlockedFormat format) {
16+
switch (format) {
17+
case BlockedFormat::Plain:
18+
return "Plain";
19+
case BlockedFormat::nChw4c:
20+
return "nChw4c";
21+
case BlockedFormat::ABcd16a4b:
22+
return "ABcd16a4b";
23+
default:
24+
return "Unknown";
25+
}
26+
}
27+
28+
BlockedFormat ParseFormat(const std::string& format_str) {
29+
if (format_str == "Plain" || format_str == "NCHW") {
30+
return BlockedFormat::Plain;
31+
} else if (format_str == "nChw4c") {
32+
return BlockedFormat::nChw4c;
33+
} else if (format_str == "ABcd16a4b") {
34+
return BlockedFormat::ABcd16a4b;
35+
} else {
36+
ORT_THROW("Unsupported format: ", format_str);
37+
}
38+
}
39+
} // namespace
40+
41+
// Captures the source/destination formats and the (possibly padded) shapes.
// The shapes are consulted by GenerateShaderCode to validate the 4D rank.
FormatTransformProgram::FormatTransformProgram(BlockedFormat src_format, BlockedFormat dst_format,
                                               const TensorShape& input_shape, const TensorShape& output_shape)
    : Program{"FormatTransform"},
      src_format_(src_format),
      dst_format_(dst_format),
      input_shape_(input_shape),
      output_shape_(output_shape) {
}
49+
50+
// Emits the WGSL for this transform by instantiating the shared template.
// The src/dst formats are baked in as template parameters, so each
// (src, dst) pair produces a distinct shader (see CacheHint in the kernel).
Status FormatTransformProgram::GenerateShaderCode(ShaderHelper& sh) const {
  const auto& input = sh.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
  const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);

  // Only 4D (NCHW) tensors are handled by the template.
  ORT_RETURN_IF_NOT(input_shape_.NumDimensions() == 4,
                    "FormatTransform currently only supports 4D tensors (NCHW)");

  // Numeric enum values must match the FORMAT_* macros in the template.
  const int src_val = static_cast<int>(src_format_);
  const int dst_val = static_cast<int>(dst_format_);

  return WGSL_TEMPLATE_APPLY(sh, "vendor/intel/contrib/format_transform.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(dst_format, dst_val),
                             WGSL_TEMPLATE_PARAMETER(src_format, src_val),
                             WGSL_TEMPLATE_VARIABLE(input, input),
                             WGSL_TEMPLATE_VARIABLE(output, output));
}
66+
67+
// Reads the "src_format"/"dst_format" string attributes (defaulting to
// "Plain") and resolves them to BlockedFormat values; ParseFormat throws on
// unrecognized names, failing kernel construction.
FormatTransform::FormatTransform(const OpKernelInfo& info)
    : WebGpuKernel(info) {
  src_format_ = ParseFormat(info.GetAttrOrDefault<std::string>("src_format", "Plain"));
  dst_format_ = ParseFormat(info.GetAttrOrDefault<std::string>("dst_format", "Plain"));
}
75+
76+
// Computes the padded output shape for the destination format, then runs the
// WGSL transform program with one invocation per output element.
// Padding rules: nChw4c pads C up to a multiple of 4; ABcd16a4b pads N up to
// 16 and C up to 4; a Plain destination keeps the input shape unchanged.
Status FormatTransform::ComputeInternal(ComputeContext& context) const {
  const auto* input = context.Input<Tensor>(0);
  const auto& input_shape = input->Shape();

  ORT_RETURN_IF_NOT(input_shape.NumDimensions() == 4, "FormatTransform only supports 4D tensors");

  // Round `v` up to the nearest multiple of `block`.
  auto round_up = [](int64_t v, int64_t block) { return ((v + block - 1) / block) * block; };

  TensorShape output_shape = input_shape;
  if (dst_format_ == BlockedFormat::nChw4c) {
    // nChw4c stores channels in blocks of 4, so C must be padded.
    output_shape = TensorShape({input_shape[0], round_up(input_shape[1], 4),
                                input_shape[2], input_shape[3]});
  } else if (dst_format_ == BlockedFormat::ABcd16a4b) {
    // ABcd16a4b blocks N by 16 and C by 4, so both must be padded.
    output_shape = TensorShape({round_up(input_shape[0], 16), round_up(input_shape[1], 4),
                                input_shape[2], input_shape[3]});
  }
  // Plain destination: no padding needed (output_shape remains input_shape).

  auto* output = context.Output(0, output_shape);

  // narrow<> (instead of static_cast) makes an overflowing element count a
  // loud failure rather than a silently truncated dispatch size; padded
  // positions beyond the input extent are zero-filled by the shader.
  const uint32_t output_size = onnxruntime::narrow<uint32_t>(output_shape.Size());

  FormatTransformProgram program{src_format_, dst_format_, input_shape, output_shape};
  program
      .AddInput({input, ProgramTensorMetadataDependency::TypeAndRank})
      .AddOutput({output, ProgramTensorMetadataDependency::None})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .CacheHint(static_cast<int>(src_format_), static_cast<int>(dst_format_))
      .AddUniformVariables({{output_size}});

  return context.RunProgram(program);
}
112+
113+
} // namespace intel
114+
} // namespace webgpu
115+
} // namespace onnxruntime
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
#include "core/providers/webgpu/program.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
namespace intel {
12+
13+
// OneDNN blocked format types.
// IMPORTANT: the numeric values are consumed by format_transform.wgsl.template
// (FORMAT_PLAIN / FORMAT_NCHW4C / FORMAT_ABCD16A4B) — keep them in sync.
// Values are spelled out explicitly so that an accidental reorder of the
// enumerators cannot silently change the contract with the shader.
enum class BlockedFormat {
  Plain = 0,      // Standard NCHW layout
  nChw4c = 1,     // Blocked with 4-channel blocks: [N, C/4, H, W, 4]
  ABcd16a4b = 2,  // 2D blocked format: blocks A (N) dim by 16, B (C) dim by 4
};
20+
21+
// Shader program that rewrites a 4D tensor between plain and blocked layouts.
// One invocation is dispatched per output element; padded positions in a
// blocked destination are zero-filled by the shader.
class FormatTransformProgram final : public Program<FormatTransformProgram> {
 public:
  // Shapes are as seen by the kernel: input_shape is the input tensor's shape,
  // output_shape is padded when the destination format is blocked.
  FormatTransformProgram(BlockedFormat src_format, BlockedFormat dst_format,
                         const TensorShape& input_shape, const TensorShape& output_shape);

  // Instantiates the WGSL template with the src/dst format values baked in.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // output_size: total element count of the (padded) output tensor, used for
  // out-of-bounds guarding in the shader.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  BlockedFormat src_format_;
  BlockedFormat dst_format_;
  TensorShape input_shape_;   // Shape of the source tensor.
  TensorShape output_shape_;  // Shape of the destination tensor (may be padded).
};
36+
37+
// Internal operator for format transformation between plain and blocked
// formats. Reads the "src_format"/"dst_format" string attributes (defaulting
// to "Plain") and dispatches a FormatTransformProgram per 4D input tensor.
class FormatTransform final : public WebGpuKernel {
 public:
  FormatTransform(const OpKernelInfo& info);
  Status ComputeInternal(ComputeContext& context) const override;

 private:
  BlockedFormat src_format_;  // Layout of the incoming tensor.
  BlockedFormat dst_format_;  // Layout to produce.
};
47+
48+
} // namespace intel
49+
} // namespace webgpu
50+
} // namespace onnxruntime
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Format transformation shader for OneDNN blocked layouts.
// Supports Plain (NCHW), nChw4c, and ABcd16a4b formats.
//
// One invocation handles one OUTPUT element: the logical NCHW coordinate
// (n, c, h, w) is derived from global_idx against the (padded) output shape,
// mapped to a physical offset in the destination layout, and written either
// with zero (padding area) or with the corresponding input element.
//
// NOTE: the FORMAT_* values must stay in sync with the BlockedFormat enum in
// format_transform.h.
// NOTE(review): when converting blocked -> Plain, input_shape carries the
// blocked tensor's padded extents, so the plain output keeps the padding;
// callers are presumably expected to track the original logical dims —
// confirm against the round-trip tests.

#define FORMAT_PLAIN 0
#define FORMAT_NCHW4C 1
#define FORMAT_ABCD16A4B 2

#use getElementAt
#use .getByOffset .setByOffset
#use guardAgainstOutOfBoundsWorkgroupSizes

#param src_format
#param dst_format

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);

  // Source tensor extents.
  let N = uniforms.input_shape[0];
  let C = uniforms.input_shape[1];
  let H = uniforms.input_shape[2];
  let W = uniforms.input_shape[3];

  // Destination extents (padded when the destination format is blocked).
  let out_N = uniforms.output_shape[0];
  let out_C = uniforms.output_shape[1];
  let out_H = uniforms.output_shape[2];
  let out_W = uniforms.output_shape[3];

  // Decompose the flat output index into logical NCHW coordinates.
  let n = global_idx / (out_C * out_H * out_W);
  let chw_idx = global_idx % (out_C * out_H * out_W);
  let c = chw_idx / (out_H * out_W);
  let hw_idx = chw_idx % (out_H * out_W);
  let h = hw_idx / out_W;
  let w = hw_idx % out_W;

  // Physical offset of (n, c, h, w) in the destination layout.
  var output_idx: u32 = 0u;
#if dst_format == FORMAT_PLAIN
  // Plain format: NCHW (out_* == input extents for a Plain destination).
  output_idx = n * C * H * W + c * H * W + h * W + w;
#elif dst_format == FORMAT_NCHW4C
  // nChw4c format: [N, C/4, H, W, 4]
  let block_size = 4u;
  let num_blocks = (out_C + block_size - 1u) / block_size;
  let block_idx = c / block_size;
  let c_in_block = c % block_size;
  output_idx = n * num_blocks * out_H * out_W * block_size +
               block_idx * out_H * out_W * block_size +
               h * out_W * block_size +
               w * block_size +
               c_in_block;
#elif dst_format == FORMAT_ABCD16A4B
  // ABcd16a4b format: [N/16, C/4, H, W, 16, 4]
  // (num_a_blocks is not needed: the leading A-block index has the largest
  // stride, so only the B-block count enters the offset computation.)
  let a_block = 16u;
  let b_block = 4u;
  let num_b_blocks = (out_C + b_block - 1u) / b_block;
  let a_block_idx = n / a_block;
  let n_in_block = n % a_block;
  let b_block_idx = c / b_block;
  let c_in_block = c % b_block;
  output_idx = a_block_idx * num_b_blocks * out_H * out_W * a_block * b_block +
               b_block_idx * out_H * out_W * a_block * b_block +
               h * out_W * a_block * b_block +
               w * a_block * b_block +
               n_in_block * b_block +
               c_in_block;
#endif

  // Positions outside the source extents are padding — zero-fill them.
  if (n >= N || c >= C || h >= H || w >= W) {
    output.setByOffset(output_idx, output_value_t(0));
  } else {
    // Physical offset of (n, c, h, w) in the source layout.
#if src_format == FORMAT_PLAIN
    // Plain format: NCHW
    let input_idx = n * C * H * W + c * H * W + h * W + w;
#elif src_format == FORMAT_NCHW4C
    // nChw4c format: [N, C/4, H, W, 4]
    let block_size = 4u;
    let num_blocks = (C + block_size - 1u) / block_size;
    let block_idx = c / block_size;
    let c_in_block = c % block_size;
    let input_idx = n * num_blocks * H * W * block_size +
                    block_idx * H * W * block_size +
                    h * W * block_size +
                    w * block_size +
                    c_in_block;
#elif src_format == FORMAT_ABCD16A4B
    // ABcd16a4b format: [N/16, C/4, H, W, 16, 4] (num_a_blocks again unneeded).
    let a_block = 16u;
    let b_block = 4u;
    let num_b_blocks = (C + b_block - 1u) / b_block;
    let a_block_idx = n / a_block;
    let n_in_block = n % a_block;
    let b_block_idx = c / b_block;
    let c_in_block = c % b_block;
    let input_idx = a_block_idx * num_b_blocks * H * W * a_block * b_block +
                    b_block_idx * H * W * a_block * b_block +
                    h * W * a_block * b_block +
                    w * a_block * b_block +
                    n_in_block * b_block +
                    c_in_block;
#endif

    output.setByOffset(output_idx, input.getByOffset(input_idx));
  }
}  // MAIN

0 commit comments

Comments
 (0)