@@ -47,11 +47,19 @@ public:
4747 template <typename F>
4848 void forwardThenBackward (MF const & inmf, MF& outmf, F const & post_forward)
4949 {
50- this ->forward_doit (inmf);
50+ this ->forward (inmf);
5151 this ->post_forward_doit (post_forward);
52- this ->backward_doit (outmf);
52+ this ->backward (outmf);
5353 }
5454
55+ void forward (MF const & inmf, Scaling scaling = Scaling::none);
56+ void forward (MF const & inmf, cMF& outmf, Scaling scaling = Scaling::none);
57+
58+ void backward (MF& outmf, Scaling scaling = Scaling::none);
59+ void backward (cMF const & inmf, MF& outmf, Scaling scaling = Scaling::none);
60+
61+ std::pair<cMF*,IntVect> getSpectralData ();
62+
5563 struct Swap01
5664 {
5765 [[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator () (Dim3 i) const noexcept
@@ -153,9 +161,6 @@ private:
153161 }
154162 }
155163
156- void forward_doit (MF const & inmf, Scaling scaling = Scaling::none);
157- void backward_doit (MF& outmf, Scaling scaling = Scaling::none);
158-
159164 static void exec_r2c (Plan plan, MF& in, cMF& out);
160165 static void exec_c2r (Plan plan, cMF& in, MF& out);
161166 template <Direction direction>
@@ -175,10 +180,10 @@ private:
175180 // Comm meta-data. In the forward phase, we start with (x,y,z),
176181 // transpose to (y,x,z) and then (z,x,y). In the backward phase, we
177182 // perform inverse transpose.
178- std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y;
179- std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x;
180- std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z;
181- std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y;
183+ std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y; // (x,y,z) -> (y,x,z)
184+ std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x; // (y,x,z) -> (x,y,z)
185+ std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z; // (y,x,z) -> (z,x,y)
186+ std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y; // (z,x,y) -> (y,x,z)
182187 Swap01 m_dtos_x2y{};
183188 Swap01 m_dtos_y2x{};
184189 Swap02 m_dtos_y2z{};
@@ -232,12 +237,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
232237 int nprocs = ParallelDescriptor::NProcs ();
233238
234239 auto bax = amrex::decompose (m_real_domain, nprocs, {AMREX_D_DECL (false ,true ,true )});
235- DistributionMapping dmx;
236- {
237- Vector<int > pm (bax.size ());
238- std::iota (pm.begin (), pm.end (), 0 );
239- dmx.define (std::move (pm));
240- }
240+ DistributionMapping dmx = detail::make_iota_distromap (bax.size ());
241241 m_rx.define (bax, dmx, 1 , 0 );
242242
243243 {
@@ -346,9 +346,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
346346 if (cbay.size () == dmx.size ()) {
347347 cdmy = dmx;
348348 } else {
349- Vector<int > pm (cbay.size ());
350- std::iota (pm.begin (), pm.end (), 0 );
351- cdmy.define (std::move (pm));
349+ cdmy = detail::make_iota_distromap (cbay.size ());
352350 }
353351 m_cy.define (cbay, cdmy, 1 , 0 );
354352
@@ -365,7 +363,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
365363
366364#if (AMREX_SPACEDIM == 3)
367365 if (m_real_domain.length (1 ) > 1 &&
368- (! m_info.batch_mode || m_real_domain.length (2 ) > 1 ))
366+ (! m_info.batch_mode && m_real_domain.length (2 ) > 1 ))
369367 {
370368 auto cbaz = amrex::decompose (m_spectral_domain_z, nprocs, {false ,true ,true });
371369 DistributionMapping cdmz;
@@ -374,9 +372,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
374372 } else if (cbaz.size () == cdmy.size ()) {
375373 cdmz = cdmy;
376374 } else {
377- Vector<int > pm (cbaz.size ());
378- std::iota (pm.begin (), pm.end (), 0 );
379- cdmz.define (std::move (pm));
375+ cdmz = detail::make_iota_distromap (cbaz.size ());
380376 }
381377 m_cz.define (cbaz, cdmz, 1 , 0 );
382378
@@ -563,8 +559,10 @@ void R2C<T>::exec_c2c (Plan2 plan, cMF& inout)
563559}
564560
565561template <typename T>
566- void R2C<T>::forward_doit (MF const & inmf, Scaling /* scaling*/ )
562+ void R2C<T>::forward (MF const & inmf, Scaling scaling)
567563{
564+ AMREX_ALWAYS_ASSERT (scaling == Scaling::none); // xxxxx TODO
565+
568566 m_rx.ParallelCopy (inmf, 0 , 0 , 1 );
569567 exec_r2c (m_fft_fwd_x, m_rx, m_cx);
570568
@@ -580,8 +578,10 @@ void R2C<T>::forward_doit (MF const& inmf, Scaling /*scaling*/)
580578}
581579
582580template <typename T>
583- void R2C<T>::backward_doit (MF& outmf, Scaling /* scaling*/ )
581+ void R2C<T>::backward (MF& outmf, Scaling scaling)
584582{
583+ AMREX_ALWAYS_ASSERT (scaling == Scaling::none); // xxxxx TODO
584+
585585 exec_c2c<Direction::backward>(m_fft_bwd_z, m_cz);
586586 if ( m_cmd_z2y) {
587587 ParallelCopy (m_cy, m_cz, *m_cmd_z2y, 0 , 0 , 1 , m_dtos_z2y);
@@ -716,6 +716,51 @@ void R2C<T>::post_forward_doit (F const& post_forward)
716716 }
717717}
718718
719+ template <typename T>
720+ std::pair<typename R2C<T>::cMF *, IntVect>
721+ R2C<T>::getSpectralData ()
722+ {
723+ if (!m_cz.empty ()) {
724+ return std::make_pair (&m_cz, IntVect{AMREX_D_DECL (2 ,0 ,1 )});
725+ } else if (!m_cy.empty ()) {
726+ return std::make_pair (&m_cy, IntVect{AMREX_D_DECL (1 ,0 ,2 )});
727+ } else {
728+ return std::make_pair (&m_cx, IntVect{AMREX_D_DECL (0 ,1 ,2 )});
729+ }
730+ }
731+
732+ template <typename T>
733+ void R2C<T>::forward (MF const & inmf, cMF& outmf, Scaling scaling)
734+ {
735+ forward (inmf);
736+ if (!m_cz.empty ()) { // m_cz's ordering is z,x,y
737+ amrex::Abort (" xxxxx todo, forward m_cz" );
738+ } else if (!m_cy.empty ()) { // m_cy's order (y,x,z) -> (x,y,z)
739+ MultiBlockCommMetaData cmd
740+ (outmf.boxArray (), outmf.DistributionMap (), m_spectral_domain_x,
741+ m_cy.boxArray (), m_cy.DistributionMap (), IntVect (0 ), m_dtos_y2x);
742+ ParallelCopy (outmf, m_cy, cmd, 0 , 0 , 1 , m_dtos_y2x);
743+ } else {
744+ outmf.ParallelCopy (m_cx, 0 , 0 , 1 );
745+ }
746+ }
747+
748+ template <typename T>
749+ void R2C<T>::backward (cMF const & inmf, MF& outmf, Scaling scaling)
750+ {
751+ if (!m_cz.empty ()) { // m_cz's ordering is z,x,y
752+ amrex::Abort (" xxxxx todo, backward m_cz" );
753+ } else if (!m_cy.empty ()) { // (x,y,z) -> m_cy's ordering (y,x,z)
754+ MultiBlockCommMetaData cmd
755+ (m_cy.boxArray (), m_cy.DistributionMap (), m_spectral_domain_y,
756+ inmf.boxArray (), inmf.DistributionMap (), IntVect (0 ), m_dtos_x2y);
757+ ParallelCopy (m_cy, inmf, cmd, 0 , 0 , 1 , m_dtos_x2y);
758+ } else {
759+ m_cx.ParallelCopy (inmf, 0 , 0 , 1 );
760+ }
761+ backward (outmf);
762+ }
763+
719764}
720765
721766#endif
0 commit comments