@@ -164,6 +164,8 @@ public:
164164 template <typename F>
165165 void post_forward_doit (F const & post_forward);
166166
167+ void prepare_openbc ();
168+
167169private:
168170
169171 static std::pair<Plan<T>,Plan<T>> make_c2c_plans (cMF& inout);
@@ -176,6 +178,8 @@ private:
176178 Plan<T> m_fft_bwd_y{};
177179 Plan<T> m_fft_fwd_z{};
178180 Plan<T> m_fft_bwd_z{};
181+ Plan<T> m_fft_fwd_x_half{};
182+ Plan<T> m_fft_bwd_x_half{};
179183
180184 // Comm meta-data. In the forward phase, we start with (x,y,z),
181185 // transpose to (y,x,z) and then (z,x,y). In the backward phase, we
@@ -394,6 +398,60 @@ R2C<T,D,S>::~R2C<T,D,S> ()
394398 m_fft_fwd_x.destroy ();
395399 m_fft_fwd_y.destroy ();
396400 m_fft_fwd_z.destroy ();
401+ if (m_fft_bwd_x_half.plan != m_fft_fwd_x_half.plan ) {
402+ m_fft_bwd_x_half.destroy ();
403+ }
404+ m_fft_fwd_x_half.destroy ();
405+ }
406+
407+ template <typename T, Direction D, DomainStrategy S>
408+ void R2C<T,D,S>::prepare_openbc ()
409+ {
410+ #if (AMREX_SPACEDIM == 3)
411+ if (m_slab_decomp) {
412+ auto * fab = detail::get_fab (m_rx);
413+ if (fab) {
414+ Box bottom_half = m_real_domain;
415+ bottom_half.growHi (2 ,-m_real_domain.length (2 )/2 );
416+ Box box = fab->box () & bottom_half;
417+ if (box.ok ()) {
418+ auto * pr = fab->dataPtr ();
419+ auto * pc = (typename Plan<T>::VendorComplex *)
420+ detail::get_fab (m_cx)->dataPtr ();
421+ #ifdef AMREX_USE_SYCL
422+ m_fft_fwd_x_half.template init_r2c <Direction::forward>
423+ (box, pr, pc, m_slab_decomp);
424+ m_fft_bwd_x_half = m_fft_fwd_x_half;
425+ #else
426+ if constexpr (D == Direction::both || D == Direction::forward) {
427+ m_fft_fwd_x_half.template init_r2c <Direction::forward>
428+ (box, pr, pc, m_slab_decomp);
429+ }
430+ if constexpr (D == Direction::both || D == Direction::backward) {
431+ m_fft_bwd_x_half.template init_r2c <Direction::backward>
432+ (box, pr, pc, m_slab_decomp);
433+ }
434+ #endif
435+ }
436+ }
437+ } // else todo
438+
439+ if (m_cmd_x2z && ! m_cmd_x2z_half) {
440+ Box bottom_half = m_spectral_domain_z;
441+ // Note that z-direction's index is 0 because we z is the
442+ // unit-stride direction here.
443+ bottom_half.growHi (0 ,-m_spectral_domain_z.length (0 )/2 );
444+ m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
445+ (m_cz, bottom_half, m_cx, IntVect (0 ), m_dtos_x2z);
446+ }
447+
448+ if (m_cmd_z2x && ! m_cmd_z2x_half) {
449+ Box bottom_half = m_spectral_domain_x;
450+ bottom_half.growHi (2 ,-m_spectral_domain_x.length (2 )/2 );
451+ m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
452+ (m_cx, bottom_half, m_cz, IntVect (0 ), m_dtos_z2x);
453+ }
454+ #endif
397455}
398456
399457template <typename T, Direction D, DomainStrategy S>
@@ -406,7 +464,8 @@ void R2C<T,D,S>::forward (MF const& inmf)
406464 if (&m_rx != &inmf) {
407465 m_rx.ParallelCopy (inmf, 0 , 0 , 1 );
408466 }
409- m_fft_fwd_x.template compute_r2c <Direction::forward>();
467+ auto & fft_x = m_openbc_half ? m_fft_fwd_x_half : m_fft_fwd_x;
468+ fft_x.template compute_r2c <Direction::forward>();
410469
411470 if ( m_cmd_x2y) {
412471 ParallelCopy (m_cy, m_cx, *m_cmd_x2y, 0 , 0 , 1 , m_dtos_x2y);
@@ -419,19 +478,16 @@ void R2C<T,D,S>::forward (MF const& inmf)
419478#if (AMREX_SPACEDIM == 3)
420479 else if ( m_cmd_x2z) {
421480 if (m_openbc_half) {
422- Box upper_half = m_spectral_domain_z;
423- // Note that z-direction's index is 0 because we z is the unit-stride direction here.
424- upper_half.growLo (0 ,-m_spectral_domain_z.length (0 )/2 );
425- if (! m_cmd_x2z_half) {
426- Box bottom_half = m_spectral_domain_z;
427- bottom_half.growHi (0 ,-m_spectral_domain_z.length (0 )/2 );
428- m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
429- (m_cz, bottom_half, m_cx, IntVect (0 ), m_dtos_x2z);
430- }
431481 NonLocalBC::ApplyDtosAndProjectionOnReciever packing
432482 {NonLocalBC::PackComponents{}, m_dtos_x2z};
433483 auto handler = ParallelCopy_nowait (m_cz, m_cx, *m_cmd_x2z_half, packing);
484+
485+ Box upper_half = m_spectral_domain_z;
486+ // Note that z-direction's index is 0 because we z is the
487+ // unit-stride direction here.
488+ upper_half.growLo (0 ,-m_spectral_domain_z.length (0 )/2 );
434489 m_cz.setVal (0 , upper_half, 0 , 1 );
490+
435491 ParallelCopy_finish (m_cz, std::move (handler), *m_cmd_x2z_half, packing);
436492 } else {
437493 ParallelCopy (m_cz, m_cx, *m_cmd_x2z, 0 , 0 , 1 , m_dtos_x2z);
@@ -459,22 +515,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
459515 }
460516#if (AMREX_SPACEDIM == 3)
461517 else if ( m_cmd_z2x) {
462- if (m_openbc_half) {
463- Box upper_half = m_spectral_domain_x;
464- upper_half.growLo (2 ,-m_spectral_domain_x.length (2 )/2 );
465- if (! m_cmd_z2x_half) {
466- Box bottom_half = m_spectral_domain_x;
467- bottom_half.growHi (2 ,-m_spectral_domain_x.length (2 )/2 );
468- m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
469- (m_cx, bottom_half, m_cz, IntVect (0 ), m_dtos_z2x);
470- }
471- NonLocalBC::ApplyDtosAndProjectionOnReciever packing
472- {NonLocalBC::PackComponents{}, m_dtos_z2x};
473- auto handler = ParallelCopy_nowait (m_cx, m_cz, *m_cmd_z2x_half, packing);
474- ParallelCopy_finish (m_cx, std::move (handler), *m_cmd_z2x_half, packing);
475- } else {
476- ParallelCopy (m_cx, m_cz, *m_cmd_z2x, 0 , 0 , 1 , m_dtos_z2x);
477- }
518+ auto const & cmd = m_openbc_half ? m_cmd_z2x_half : m_cmd_z2x;
519+ ParallelCopy (m_cx, m_cz, *cmd, 0 , 0 , 1 , m_dtos_z2x);
478520 }
479521#endif
480522
@@ -483,7 +525,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
483525 ParallelCopy (m_cx, m_cy, *m_cmd_y2x, 0 , 0 , 1 , m_dtos_y2x);
484526 }
485527
486- m_fft_bwd_x.template compute_r2c <Direction::backward>();
528+ auto & fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x;
529+ fft_x.template compute_r2c <Direction::backward>();
487530 outmf.ParallelCopy (m_rx, 0 , 0 , 1 , IntVect (0 ), ngout);
488531}
489532
0 commit comments