Skip to content

Commit c572732

Browse files
CPU: uni_planar_convolution optimization
1 parent f9e363f commit c572732

File tree

2 files changed

+25
-30
lines changed

2 files changed

+25
-30
lines changed

src/cpu/x64/jit_uni_planar_conv_kernel_f32.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ status_t jit_uni_planar_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp
723723
const auto &p = attr.post_ops_;
724724
jcp.with_sum = p.find(primitive_kind::sum) != -1;
725725

726-
const int simd_w = isa == avx512_core ? 16 : 8;
726+
const int simd_w = isa == avx512_core ? 16 : isa == avx2 ? 8 : 4;
727727

728728
auto set_or_check_wei_format = [&]() {
729729
using namespace format_tag;

src/cpu/x64/jit_uni_planar_convolution.cpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,28 +58,35 @@ void _jit_uni_planar_convolution_fwd_t<isa>::execute_forward(const exec_ctx_t &c
5858

5959
const auto &jcp = pd()->jcp_;
6060

61-
std::vector<int> od_indexes(jcp.od);
61+
std::vector<int> oh_indexes(jcp.oh);
6262

6363
int idx = 0;
64-
for (int i = 0; i < (jcp.dilate_d + 1); i++) {
65-
for (int ib = 0; ib < jcp.od; ib += (jcp.dilate_d + 1)) {
66-
if (ib + i >= jcp.od)
64+
for (int i = 0; i < (jcp.dilate_h + 1); i++) {
65+
for (int ib = 0; ib < jcp.oh; ib += (jcp.dilate_h + 1)) {
66+
if (ib + i >= jcp.oh)
6767
continue;
6868

69-
od_indexes[idx++] = ib + i;
70-
if (idx >= jcp.od)
69+
oh_indexes[idx++] = ib + i;
70+
if (idx >= jcp.oh)
7171
break;
7272
}
73-
if (idx >= jcp.od)
73+
if (idx >= jcp.oh)
7474
break;
7575
}
7676

7777
int threads_count = dnnl_get_max_threads();
78-
int odb_size = div_up(jcp.od, threads_count);
78+
int ohb_size = div_up(jcp.oh, threads_count);
7979

80-
auto kernel_params = [&](int n, int g, int icb, int oc, int od, int oh, int oh_blocks, int id, int wd, int kd_padding) {
80+
auto kernel_params = [&](int n, int g, int icb, int oc, int od, int oh, int oh_blocks) {
8181
auto par_conv = jit_conv_call_s();
8282

83+
const int dj = od * jcp.stride_d;
84+
const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
85+
const int d_b_overflow = nstl::max(jcp.id, dj + (jcp.kd - 1) * (jcp.dilate_d + 1) - jcp.f_pad + 1) - jcp.id;
86+
const int id = nstl::max(dj - jcp.f_pad + div_up(d_t_overflow, (jcp.dilate_d + 1)) * (jcp.dilate_d + 1), 0);
87+
const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
88+
const int kd_padding = jcp.kd - div_up(d_t_overflow, (jcp.dilate_d + 1)) - div_up(d_b_overflow, (jcp.dilate_d + 1));
89+
8390
const int hj = oh * jcp.stride_h;
8491
const int i_t_overflow = nstl::max(0, jcp.t_pad - hj);
8592
const int i_b_overflow = nstl::max(jcp.ih, hj + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad + 1) - jcp.ih;
@@ -126,27 +133,15 @@ void _jit_uni_planar_convolution_fwd_t<isa>::execute_forward(const exec_ctx_t &c
126133
icb_step = icb_step_rem;
127134

128135
for (int icb = icbb; icb < icbb + icb_step; ++icb) {
129-
for (int ohb = 0; ohb < (jcp.dilate_h + 1); ohb++) {
130-
for (int oh = ohb; oh < jcp.oh; oh += (jcp.dilate_h + 1)) {
131-
int od_idx_off = ithr * odb_size;
132-
for (int od_idx = 0; od_idx < odb_size; od_idx++) {
133-
if ((od_idx_off + od_idx) >= jcp.od || od_indexes[od_idx_off + od_idx] >= jcp.od)
136+
for (int odb = 0; odb < (jcp.dilate_d + 1); odb++) {
137+
for (int od = odb; od < jcp.od; od += (jcp.dilate_d + 1)) {
138+
int oh_idx_off = ithr * ohb_size;
139+
for (int oh_idx = 0; oh_idx < ohb_size; oh_idx++) {
140+
if ((oh_idx_off + oh_idx) >= jcp.oh || oh_indexes[oh_idx_off + oh_idx] >= jcp.oh)
134141
continue;
135-
int od = od_indexes[od_idx_off + od_idx];
136-
137-
const int dj = od * jcp.stride_d;
138-
const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
139-
const int d_b_overflow =
140-
nstl::max(jcp.id, dj + (jcp.kd - 1) * (jcp.dilate_d + 1) - jcp.f_pad + 1) -
141-
jcp.id;
142-
const int id = nstl::max(dj - jcp.f_pad +
143-
div_up(d_t_overflow, (jcp.dilate_d + 1)) * (jcp.dilate_d + 1),
144-
0);
145-
const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
146-
const int kd_padding = jcp.kd - div_up(d_t_overflow, (jcp.dilate_d + 1)) -
147-
div_up(d_b_overflow, (jcp.dilate_d + 1));
148-
149-
jit_conv_call_s par_conv = kernel_params(n, g, icb, oc, od, oh, 1, id, wd, kd_padding);
142+
int oh = oh_indexes[oh_idx_off + oh_idx];
143+
144+
jit_conv_call_s par_conv = kernel_params(n, g, icb, oc, od, oh, 1);
150145

151146
(*kernel_)(&par_conv);
152147
}

0 commit comments

Comments
 (0)