@@ -58,28 +58,35 @@ void _jit_uni_planar_convolution_fwd_t<isa>::execute_forward(const exec_ctx_t &c
5858
5959 const auto &jcp = pd ()->jcp_ ;
6060
61- std::vector<int > od_indexes (jcp.od );
61+ std::vector<int > oh_indexes (jcp.oh );
6262
6363 int idx = 0 ;
64- for (int i = 0 ; i < (jcp.dilate_d + 1 ); i++) {
65- for (int ib = 0 ; ib < jcp.od ; ib += (jcp.dilate_d + 1 )) {
66- if (ib + i >= jcp.od )
64+ for (int i = 0 ; i < (jcp.dilate_h + 1 ); i++) {
65+ for (int ib = 0 ; ib < jcp.oh ; ib += (jcp.dilate_h + 1 )) {
66+ if (ib + i >= jcp.oh )
6767 continue ;
6868
69- od_indexes [idx++] = ib + i;
70- if (idx >= jcp.od )
69+ oh_indexes [idx++] = ib + i;
70+ if (idx >= jcp.oh )
7171 break ;
7272 }
73- if (idx >= jcp.od )
73+ if (idx >= jcp.oh )
7474 break ;
7575 }
7676
7777 int threads_count = dnnl_get_max_threads ();
78- int odb_size = div_up (jcp.od , threads_count);
78+ int ohb_size = div_up (jcp.oh , threads_count);
7979
80- auto kernel_params = [&](int n, int g, int icb, int oc, int od, int oh, int oh_blocks, int id, int wd, int kd_padding ) {
80+ auto kernel_params = [&](int n, int g, int icb, int oc, int od, int oh, int oh_blocks) {
8181 auto par_conv = jit_conv_call_s ();
8282
83+ const int dj = od * jcp.stride_d ;
84+ const int d_t_overflow = nstl::max (0 , jcp.f_pad - dj);
85+ const int d_b_overflow = nstl::max (jcp.id , dj + (jcp.kd - 1 ) * (jcp.dilate_d + 1 ) - jcp.f_pad + 1 ) - jcp.id ;
86+ const int id = nstl::max (dj - jcp.f_pad + div_up (d_t_overflow, (jcp.dilate_d + 1 )) * (jcp.dilate_d + 1 ), 0 );
87+ const int wd = div_up (d_t_overflow, (jcp.dilate_d + 1 ));
88+ const int kd_padding = jcp.kd - div_up (d_t_overflow, (jcp.dilate_d + 1 )) - div_up (d_b_overflow, (jcp.dilate_d + 1 ));
89+
8390 const int hj = oh * jcp.stride_h ;
8491 const int i_t_overflow = nstl::max (0 , jcp.t_pad - hj);
8592 const int i_b_overflow = nstl::max (jcp.ih , hj + (jcp.kh - 1 ) * (jcp.dilate_h + 1 ) - jcp.t_pad + 1 ) - jcp.ih ;
@@ -126,27 +133,15 @@ void _jit_uni_planar_convolution_fwd_t<isa>::execute_forward(const exec_ctx_t &c
126133 icb_step = icb_step_rem;
127134
128135 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
129- for (int ohb = 0 ; ohb < (jcp.dilate_h + 1 ); ohb ++) {
130- for (int oh = ohb; oh < jcp.oh ; oh += (jcp.dilate_h + 1 )) {
131- int od_idx_off = ithr * odb_size ;
132- for (int od_idx = 0 ; od_idx < odb_size; od_idx ++) {
133- if ((od_idx_off + od_idx ) >= jcp.od || od_indexes[od_idx_off + od_idx ] >= jcp.od )
136+ for (int odb = 0 ; odb < (jcp.dilate_d + 1 ); odb ++) {
137+ for (int od = odb; od < jcp.od ; od += (jcp.dilate_d + 1 )) {
138+ int oh_idx_off = ithr * ohb_size ;
139+ for (int oh_idx = 0 ; oh_idx < ohb_size; oh_idx ++) {
140+ if ((oh_idx_off + oh_idx ) >= jcp.oh || oh_indexes[oh_idx_off + oh_idx ] >= jcp.oh )
134141 continue ;
135- int od = od_indexes[od_idx_off + od_idx];
136-
137- const int dj = od * jcp.stride_d ;
138- const int d_t_overflow = nstl::max (0 , jcp.f_pad - dj);
139- const int d_b_overflow =
140- nstl::max (jcp.id , dj + (jcp.kd - 1 ) * (jcp.dilate_d + 1 ) - jcp.f_pad + 1 ) -
141- jcp.id ;
142- const int id = nstl::max (dj - jcp.f_pad +
143- div_up (d_t_overflow, (jcp.dilate_d + 1 )) * (jcp.dilate_d + 1 ),
144- 0 );
145- const int wd = div_up (d_t_overflow, (jcp.dilate_d + 1 ));
146- const int kd_padding = jcp.kd - div_up (d_t_overflow, (jcp.dilate_d + 1 )) -
147- div_up (d_b_overflow, (jcp.dilate_d + 1 ));
148-
149- jit_conv_call_s par_conv = kernel_params (n, g, icb, oc, od, oh, 1 , id, wd, kd_padding);
142+ int oh = oh_indexes[oh_idx_off + oh_idx];
143+
144+ jit_conv_call_s par_conv = kernel_params (n, g, icb, oc, od, oh, 1 );
150145
151146 (*kernel_)(&par_conv);
152147 }
0 commit comments