@@ -193,6 +193,8 @@ public:
193193 void solve (MF& soln, MF const & rhs, Vector<T> const & dz);
194194 void solve (MF& soln, MF const & rhs, Gpu::DeviceVector<T> const & dz);
195195
196+ void solve_2d (MF& a_soln, MF const & a_rhs);
197+
196198 template <typename TRIA, typename TRIC>
197199 void solve (MF& a_soln, MF const & a_rhs, TRIA const & tria, TRIC const & tric);
198200
@@ -350,6 +352,14 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
350352#endif /* AMREX_SPACEDIM == 3 */
351353
352354namespace fft_poisson_detail {
355+ template <typename T>
356+ struct Tri_Zero {
357+ static constexpr T operator () (int , int , int )
358+ {
359+ return 0 ;
360+ }
361+ };
362+
353363 template <typename T>
354364 struct Tri_Uniform {
355365 [[nodiscard]] AMREX_GPU_DEVICE AMREX_FORCE_INLINE
@@ -480,6 +490,12 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
480490 fft_poisson_detail::TriC<T>{pdz,int (dz.size ())});
481491}
482492
493+ template <typename MF>
494+ void PoissonHybrid<MF>::solve_2d (MF& soln, MF const & rhs)
495+ {
496+ solve (soln, rhs, fft_poisson_detail::Tri_Zero<T>{}, fft_poisson_detail::Tri_Zero<T>{});
497+ }
498+
483499template <typename MF>
484500template <typename TRIA, typename TRIC>
485501void PoissonHybrid<MF>::solve (MF& a_soln, MF const & a_rhs, TRIA const & tria,
@@ -565,150 +581,176 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
565581 }
566582 }
567583
568- bool zlo_neumann = m_bc[2 ].first == Boundary::even;
569- bool zhi_neumann = m_bc[2 ].second == Boundary::even;
570- bool is_singular = (offset[0 ] == T (0 )) && (offset[1 ] == T (0 ))
571- && zlo_neumann && zhi_neumann;
584+ if constexpr (std::is_same_v<TRIA,fft_poisson_detail::Tri_Zero<T>> &&
585+ std::is_same_v<TRIC,fft_poisson_detail::Tri_Zero<T>>) {
586+ #if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
587+ #pragma omp parallel
588+ #endif
589+ for (MFIter mfi (spmf,TilingIfNotGPU ()); mfi.isValid (); ++mfi)
590+ {
591+ auto const & spectral = spmf.array (mfi);
592+ auto const & box = mfi.validbox ();
593+ amrex::ParallelFor (box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
594+ {
595+ T a = facx*(i+offset[0 ]);
596+ T b = facy*(j+offset[1 ]);
597+ T k2 = dxfac * (std::cos (a)-T (1 ))
598+ + dyfac * (std::cos (b)-T (1 ));
599+ if (k2 != T (0 )) {
600+ spectral (i,j,k) /= k2;
601+ }
602+ spectral (i,j,k) *= scale;
603+ });
604+ }
605+ } else {
606+ bool zlo_neumann = m_bc[2 ].first == Boundary::even;
607+ bool zhi_neumann = m_bc[2 ].second == Boundary::even;
608+ bool is_singular = (offset[0 ] == T (0 )) && (offset[1 ] == T (0 ))
609+ && zlo_neumann && zhi_neumann;
572610
573- auto nz = m_geom.Domain ().length (2 );
611+ auto nz = m_geom.Domain ().length (2 );
574612
575- for (MFIter mfi (spmf); mfi.isValid (); ++mfi)
576- {
577- auto const & spectral = spmf.array (mfi);
578- auto const & box = mfi.validbox ();
579- auto const & xybox = amrex::makeSlab (box, 2 , 0 );
613+ #if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
614+ #pragma omp parallel
615+ #endif
616+ for (MFIter mfi (spmf); mfi.isValid (); ++mfi)
617+ {
618+ auto const & spectral = spmf.array (mfi);
619+ auto const & box = mfi.validbox ();
620+ auto const & xybox = amrex::makeSlab (box, 2 , 0 );
580621
581622#ifdef AMREX_USE_GPU
582- // xxxxx TODO: We need to explore how to optimize this
583- // function. Maybe we can use cusparse. Maybe we should make
584- // z-direction to be the unit stride direction.
623+ // xxxxx TODO: We need to explore how to optimize this
624+ // function. Maybe we can use cusparse. Maybe we should make
625+ // z-direction to be the unit stride direction.
585626
586- FArrayBox tridiag_workspace (box,4 );
587- auto const & ald = tridiag_workspace.array (0 );
588- auto const & bd = tridiag_workspace.array (1 );
589- auto const & cud = tridiag_workspace.array (2 );
590- auto const & scratch = tridiag_workspace.array (3 );
627+ FArrayBox tridiag_workspace (box,4 );
628+ auto const & ald = tridiag_workspace.array (0 );
629+ auto const & bd = tridiag_workspace.array (1 );
630+ auto const & cud = tridiag_workspace.array (2 );
631+ auto const & scratch = tridiag_workspace.array (3 );
591632
592- amrex::ParallelFor (xybox, [=] AMREX_GPU_DEVICE (int i, int j, int )
593- {
594- T a = facx*(i+offset[0 ]);
595- T b = facy*(j+offset[1 ]);
596- T k2 = dxfac * (std::cos (a)-T (1 ))
597- + dyfac * (std::cos (b)-T (1 ));
598-
599- // Tridiagonal solve
600- for (int k=0 ; k < nz; k++) {
601- if (k==0 ) {
602- ald (i,j,k) = T (0 .);
603- cud (i,j,k) = tric (i,j,k);
604- if (zlo_neumann) {
605- bd (i,j,k) = k2 - cud (i,j,k);
606- } else {
607- bd (i,j,k) = k2 - cud (i,j,k) - T (2.0 )*tria (i,j,k);
608- }
609- } else if (k == nz-1 ) {
610- ald (i,j,k) = tria (i,j,k);
611- cud (i,j,k) = T (0 .);
612- if (zhi_neumann) {
613- bd (i,j,k) = k2 - ald (i,j,k);
614- if (i == 0 && j == 0 && is_singular) {
615- bd (i,j,k) *= T (2.0 );
633+ amrex::ParallelFor (xybox, [=] AMREX_GPU_DEVICE (int i, int j, int )
634+ {
635+ T a = facx*(i+offset[0 ]);
636+ T b = facy*(j+offset[1 ]);
637+ T k2 = dxfac * (std::cos (a)-T (1 ))
638+ + dyfac * (std::cos (b)-T (1 ));
639+
640+ // Tridiagonal solve
641+ for (int k=0 ; k < nz; k++) {
642+ if (k==0 ) {
643+ ald (i,j,k) = T (0 .);
644+ cud (i,j,k) = tric (i,j,k);
645+ if (zlo_neumann) {
646+ bd (i,j,k) = k2 - cud (i,j,k);
647+ } else {
648+ bd (i,j,k) = k2 - cud (i,j,k) - T (2.0 )*tria (i,j,k);
649+ }
650+ } else if (k == nz-1 ) {
651+ ald (i,j,k) = tria (i,j,k);
652+ cud (i,j,k) = T (0 .);
653+ if (zhi_neumann) {
654+ bd (i,j,k) = k2 - ald (i,j,k);
655+ if (i == 0 && j == 0 && is_singular) {
656+ bd (i,j,k) *= T (2.0 );
657+ }
658+ } else {
659+ bd (i,j,k) = k2 - ald (i,j,k) - T (2.0 )*tric (i,j,k);
616660 }
617661 } else {
618- bd (i,j,k) = k2 - ald (i,j,k) - T (2.0 )*tric (i,j,k);
662+ ald (i,j,k) = tria (i,j,k);
663+ cud (i,j,k) = tric (i,j,k);
664+ bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
619665 }
620- } else {
621- ald (i,j,k) = tria (i,j,k);
622- cud (i,j,k) = tric (i,j,k);
623- bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
624666 }
625- }
626667
627- scratch (i,j,0 ) = cud (i,j,0 )/bd (i,j,0 );
628- spectral (i,j,0 ) = spectral (i,j,0 )/bd (i,j,0 );
668+ scratch (i,j,0 ) = cud (i,j,0 )/bd (i,j,0 );
669+ spectral (i,j,0 ) = spectral (i,j,0 )/bd (i,j,0 );
629670
630- for (int k = 1 ; k < nz; k++) {
631- if (k < nz-1 ) {
632- scratch (i,j,k) = cud (i,j,k) / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
671+ for (int k = 1 ; k < nz; k++) {
672+ if (k < nz-1 ) {
673+ scratch (i,j,k) = cud (i,j,k) / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
674+ }
675+ spectral (i,j,k) = (spectral (i,j,k) - ald (i,j,k) * spectral (i,j,k - 1 ))
676+ / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
633677 }
634- spectral (i,j,k) = (spectral (i,j,k) - ald (i,j,k) * spectral (i,j,k - 1 ))
635- / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
636- }
637678
638- for (int k = nz - 2 ; k >= 0 ; k--) {
639- spectral (i,j,k) -= scratch (i,j,k) * spectral (i,j,k + 1 );
640- }
679+ for (int k = nz - 2 ; k >= 0 ; k--) {
680+ spectral (i,j,k) -= scratch (i,j,k) * spectral (i,j,k + 1 );
681+ }
641682
642- for (int k = 0 ; k < nz; ++k) {
643- spectral (i,j,k) *= scale;
644- }
645- });
646- Gpu::streamSynchronize ();
683+ for (int k = 0 ; k < nz; ++k) {
684+ spectral (i,j,k) *= scale;
685+ }
686+ });
687+ Gpu::streamSynchronize ();
647688
648689#else
649690
650- Gpu::DeviceVector<T> ald (nz);
651- Gpu::DeviceVector<T> bd (nz);
652- Gpu::DeviceVector<T> cud (nz);
653- Gpu::DeviceVector<T> scratch (nz);
691+ Gpu::DeviceVector<T> ald (nz);
692+ Gpu::DeviceVector<T> bd (nz);
693+ Gpu::DeviceVector<T> cud (nz);
694+ Gpu::DeviceVector<T> scratch (nz);
654695
655- amrex::LoopOnCpu (xybox, [&] (int i, int j, int )
656- {
657- T a = facx*(i+offset[0 ]);
658- T b = facy*(j+offset[1 ]);
659- T k2 = dxfac * (std::cos (a)-T (1 ))
660- + dyfac * (std::cos (b)-T (1 ));
661-
662- // Tridiagonal solve
663- for (int k=0 ; k < nz; k++) {
664- if (k==0 ) {
665- ald[k] = T (0 .);
666- cud[k] = tric (i,j,k);
667- if (zlo_neumann) {
668- bd[k] = k2 - cud[k];
669- } else {
670- bd[k] = k2 - cud[k] - T (2.0 )*tria (i,j,k);
671- }
672- } else if (k == nz-1 ) {
673- ald[k] = tria (i,j,k);
674- cud[k] = T (0 .);
675- if (zhi_neumann) {
676- bd[k] = k2 - ald[k];
677- if (i == 0 && j == 0 && is_singular) {
678- bd[k] *= T (2.0 );
696+ amrex::LoopOnCpu (xybox, [&] (int i, int j, int )
697+ {
698+ T a = facx*(i+offset[0 ]);
699+ T b = facy*(j+offset[1 ]);
700+ T k2 = dxfac * (std::cos (a)-T (1 ))
701+ + dyfac * (std::cos (b)-T (1 ));
702+
703+ // Tridiagonal solve
704+ for (int k=0 ; k < nz; k++) {
705+ if (k==0 ) {
706+ ald[k] = T (0 .);
707+ cud[k] = tric (i,j,k);
708+ if (zlo_neumann) {
709+ bd[k] = k2 - cud[k];
710+ } else {
711+ bd[k] = k2 - cud[k] - T (2.0 )*tria (i,j,k);
712+ }
713+ } else if (k == nz-1 ) {
714+ ald[k] = tria (i,j,k);
715+ cud[k] = T (0 .);
716+ if (zhi_neumann) {
717+ bd[k] = k2 - ald[k];
718+ if (i == 0 && j == 0 && is_singular) {
719+ bd[k] *= T (2.0 );
720+ }
721+ } else {
722+ bd[k] = k2 - ald[k] - T (2.0 )*tric (i,j,k);
679723 }
680724 } else {
681- bd[k] = k2 - ald[k] - T (2.0 )*tric (i,j,k);
725+ ald[k] = tria (i,j,k);
726+ cud[k] = tric (i,j,k);
727+ bd[k] = k2 -ald[k]-cud[k];
682728 }
683- } else {
684- ald[k] = tria (i,j,k);
685- cud[k] = tric (i,j,k);
686- bd[k] = k2 -ald[k]-cud[k];
687729 }
688- }
689730
690- scratch[0 ] = cud[0 ]/bd[0 ];
691- spectral (i,j,0 ) = spectral (i,j,0 )/bd[0 ];
731+ scratch[0 ] = cud[0 ]/bd[0 ];
732+ spectral (i,j,0 ) = spectral (i,j,0 )/bd[0 ];
692733
693- for (int k = 1 ; k < nz; k++) {
694- if (k < nz-1 ) {
695- scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1 ]);
734+ for (int k = 1 ; k < nz; k++) {
735+ if (k < nz-1 ) {
736+ scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1 ]);
737+ }
738+ spectral (i,j,k) = (spectral (i,j,k) - ald[k] * spectral (i,j,k - 1 ))
739+ / (bd[k] - ald[k] * scratch[k-1 ]);
696740 }
697- spectral (i,j,k) = (spectral (i,j,k) - ald[k] * spectral (i,j,k - 1 ))
698- / (bd[k] - ald[k] * scratch[k-1 ]);
699- }
700741
701- for (int k = nz - 2 ; k >= 0 ; k--) {
702- spectral (i,j,k) -= scratch[k] * spectral (i,j,k + 1 );
703- }
742+ for (int k = nz - 2 ; k >= 0 ; k--) {
743+ spectral (i,j,k) -= scratch[k] * spectral (i,j,k + 1 );
744+ }
704745
705- for (int k = 0 ; k < nz; ++k) {
706- spectral (i,j,k) *= scale;
707- }
708- });
746+ for (int k = 0 ; k < nz; ++k) {
747+ spectral (i,j,k) *= scale;
748+ }
749+ });
709750#endif
710- }
751+ }
711752#endif
753+ }
712754}
713755
714756namespace detail {
0 commit comments