@@ -193,6 +193,8 @@ public:
193193 void solve (MF& soln, MF const & rhs, Vector<T> const & dz);
194194 void solve (MF& soln, MF const & rhs, Gpu::DeviceVector<T> const & dz);
195195
196+ void solve_2d (MF& a_soln, MF const & a_rhs);
197+
196198 template <typename TRIA, typename TRIC>
197199 void solve (MF& a_soln, MF const & a_rhs, TRIA const & tria, TRIC const & tric);
198200
@@ -350,6 +352,14 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
350352#endif /* AMREX_SPACEDIM == 3 */
351353
352354namespace fft_poisson_detail {
355+ template <typename T>
356+ struct Tri_Zero {
357+ [[nodiscard]] constexpr T operator () (int , int , int ) const
358+ {
359+ return 0 ;
360+ }
361+ };
362+
353363 template <typename T>
354364 struct Tri_Uniform {
355365 [[nodiscard]] AMREX_GPU_DEVICE AMREX_FORCE_INLINE
@@ -480,6 +490,12 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
480490 fft_poisson_detail::TriC<T>{pdz,int (dz.size ())});
481491}
482492
493+ template <typename MF>
494+ void PoissonHybrid<MF>::solve_2d (MF& soln, MF const & rhs)
495+ {
496+ solve (soln, rhs, fft_poisson_detail::Tri_Zero<T>{}, fft_poisson_detail::Tri_Zero<T>{});
497+ }
498+
483499template <typename MF>
484500template <typename TRIA, typename TRIC>
485501void PoissonHybrid<MF>::solve (MF& a_soln, MF const & a_rhs, TRIA const & tria,
@@ -565,148 +581,179 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
565581 }
566582 }
567583
568- bool zlo_neumann = m_bc[2 ].first == Boundary::even;
569- bool zhi_neumann = m_bc[2 ].second == Boundary::even;
570- bool is_singular = (offset[0 ] == T (0 )) && (offset[1 ] == T (0 ))
571- && zlo_neumann && zhi_neumann;
584+ if
585+ #ifndef _WIN32
586+ constexpr
587+ #endif
588+ (std::is_same_v<TRIA,fft_poisson_detail::Tri_Zero<T>> &&
589+ std::is_same_v<TRIC,fft_poisson_detail::Tri_Zero<T>>) {
590+ amrex::ignore_unused (tria,tric);
591+ #if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
592+ #pragma omp parallel
593+ #endif
594+ for (MFIter mfi (spmf,TilingIfNotGPU ()); mfi.isValid (); ++mfi)
595+ {
596+ auto const & spectral = spmf.array (mfi);
597+ auto const & box = mfi.validbox ();
598+ amrex::ParallelFor (box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
599+ {
600+ T a = facx*(i+offset[0 ]);
601+ T b = facy*(j+offset[1 ]);
602+ T k2 = dxfac * (std::cos (a)-T (1 ))
603+ + dyfac * (std::cos (b)-T (1 ));
604+ if (k2 != T (0 )) {
605+ spectral (i,j,k) /= k2;
606+ }
607+ spectral (i,j,k) *= scale;
608+ });
609+ }
610+ } else {
611+ bool zlo_neumann = m_bc[2 ].first == Boundary::even;
612+ bool zhi_neumann = m_bc[2 ].second == Boundary::even;
613+ bool is_singular = (offset[0 ] == T (0 )) && (offset[1 ] == T (0 ))
614+ && zlo_neumann && zhi_neumann;
572615
573- auto nz = m_geom.Domain ().length (2 );
616+ auto nz = m_geom.Domain ().length (2 );
574617
575- for (MFIter mfi (spmf); mfi.isValid (); ++mfi)
576- {
577- auto const & spectral = spmf.array (mfi);
578- auto const & box = mfi.validbox ();
579- auto const & xybox = amrex::makeSlab (box, 2 , 0 );
618+ #if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
619+ #pragma omp parallel
620+ #endif
621+ for (MFIter mfi (spmf); mfi.isValid (); ++mfi)
622+ {
623+ auto const & spectral = spmf.array (mfi);
624+ auto const & box = mfi.validbox ();
625+ auto const & xybox = amrex::makeSlab (box, 2 , 0 );
580626
581627#ifdef AMREX_USE_GPU
582- // xxxxx TODO: We need to explore how to optimize this
583- // function. Maybe we can use cusparse. Maybe we should make
584- // z-direction to be the unit stride direction.
628+ // xxxxx TODO: We need to explore how to optimize this
629+ // function. Maybe we can use cusparse. Maybe we should make
630+ // z-direction to be the unit stride direction.
585631
586- FArrayBox tridiag_workspace (box,4 );
587- auto const & ald = tridiag_workspace.array (0 );
588- auto const & bd = tridiag_workspace.array (1 );
589- auto const & cud = tridiag_workspace.array (2 );
590- auto const & scratch = tridiag_workspace.array (3 );
632+ FArrayBox tridiag_workspace (box,4 );
633+ auto const & ald = tridiag_workspace.array (0 );
634+ auto const & bd = tridiag_workspace.array (1 );
635+ auto const & cud = tridiag_workspace.array (2 );
636+ auto const & scratch = tridiag_workspace.array (3 );
591637
592- amrex::ParallelFor (xybox, [=] AMREX_GPU_DEVICE (int i, int j, int )
593- {
594- T a = facx*(i+offset[0 ]);
595- T b = facy*(j+offset[1 ]);
596- T k2 = dxfac * (std::cos (a)-T (1 ))
597- + dyfac * (std::cos (b)-T (1 ));
598-
599- // Tridiagonal solve
600- for (int k=0 ; k < nz; k++) {
601- if (k==0 ) {
602- ald (i,j,k) = T (0 .);
603- cud (i,j,k) = tric (i,j,k);
604- if (zlo_neumann) {
605- bd (i,j,k) = k2 - cud (i,j,k);
606- } else {
607- bd (i,j,k) = k2 - cud (i,j,k) - T (2.0 )*tria (i,j,k);
608- }
609- } else if (k == nz-1 ) {
610- ald (i,j,k) = tria (i,j,k);
611- cud (i,j,k) = T (0 .);
612- if (zhi_neumann) {
613- bd (i,j,k) = k2 - ald (i,j,k);
614- if (i == 0 && j == 0 && is_singular) {
615- bd (i,j,k) *= T (2.0 );
638+ amrex::ParallelFor (xybox, [=] AMREX_GPU_DEVICE (int i, int j, int )
639+ {
640+ T a = facx*(i+offset[0 ]);
641+ T b = facy*(j+offset[1 ]);
642+ T k2 = dxfac * (std::cos (a)-T (1 ))
643+ + dyfac * (std::cos (b)-T (1 ));
644+
645+ // Tridiagonal solve
646+ for (int k=0 ; k < nz; k++) {
647+ if (k==0 ) {
648+ ald (i,j,k) = T (0 .);
649+ cud (i,j,k) = tric (i,j,k);
650+ if (zlo_neumann) {
651+ bd (i,j,k) = k2 - cud (i,j,k);
652+ } else {
653+ bd (i,j,k) = k2 - cud (i,j,k) - T (2.0 )*tria (i,j,k);
654+ }
655+ } else if (k == nz-1 ) {
656+ ald (i,j,k) = tria (i,j,k);
657+ cud (i,j,k) = T (0 .);
658+ if (zhi_neumann) {
659+ bd (i,j,k) = k2 - ald (i,j,k);
660+ if (i == 0 && j == 0 && is_singular) {
661+ bd (i,j,k) *= T (2.0 );
662+ }
663+ } else {
664+ bd (i,j,k) = k2 - ald (i,j,k) - T (2.0 )*tric (i,j,k);
616665 }
617666 } else {
618- bd (i,j,k) = k2 - ald (i,j,k) - T (2.0 )*tric (i,j,k);
667+ ald (i,j,k) = tria (i,j,k);
668+ cud (i,j,k) = tric (i,j,k);
669+ bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
619670 }
620- } else {
621- ald (i,j,k) = tria (i,j,k);
622- cud (i,j,k) = tric (i,j,k);
623- bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
624671 }
625- }
626672
627- scratch (i,j,0 ) = cud (i,j,0 )/bd (i,j,0 );
628- spectral (i,j,0 ) = spectral (i,j,0 )/bd (i,j,0 );
673+ scratch (i,j,0 ) = cud (i,j,0 )/bd (i,j,0 );
674+ spectral (i,j,0 ) = spectral (i,j,0 )/bd (i,j,0 );
629675
630- for (int k = 1 ; k < nz; k++) {
631- if (k < nz-1 ) {
632- scratch (i,j,k) = cud (i,j,k) / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
676+ for (int k = 1 ; k < nz; k++) {
677+ if (k < nz-1 ) {
678+ scratch (i,j,k) = cud (i,j,k) / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
679+ }
680+ spectral (i,j,k) = (spectral (i,j,k) - ald (i,j,k) * spectral (i,j,k - 1 ))
681+ / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
633682 }
634- spectral (i,j,k) = (spectral (i,j,k) - ald (i,j,k) * spectral (i,j,k - 1 ))
635- / (bd (i,j,k) - ald (i,j,k) * scratch (i,j,k-1 ));
636- }
637683
638- for (int k = nz - 2 ; k >= 0 ; k--) {
639- spectral (i,j,k) -= scratch (i,j,k) * spectral (i,j,k + 1 );
640- }
684+ for (int k = nz - 2 ; k >= 0 ; k--) {
685+ spectral (i,j,k) -= scratch (i,j,k) * spectral (i,j,k + 1 );
686+ }
641687
642- for (int k = 0 ; k < nz; ++k) {
643- spectral (i,j,k) *= scale;
644- }
645- });
646- Gpu::streamSynchronize ();
688+ for (int k = 0 ; k < nz; ++k) {
689+ spectral (i,j,k) *= scale;
690+ }
691+ });
692+ Gpu::streamSynchronize ();
647693
648694#else
649695
650- Gpu::DeviceVector<T> ald (nz);
651- Gpu::DeviceVector<T> bd (nz);
652- Gpu::DeviceVector<T> cud (nz);
653- Gpu::DeviceVector<T> scratch (nz);
696+ Gpu::DeviceVector<T> ald (nz);
697+ Gpu::DeviceVector<T> bd (nz);
698+ Gpu::DeviceVector<T> cud (nz);
699+ Gpu::DeviceVector<T> scratch (nz);
654700
655- amrex::LoopOnCpu (xybox, [&] (int i, int j, int )
656- {
657- T a = facx*(i+offset[0 ]);
658- T b = facy*(j+offset[1 ]);
659- T k2 = dxfac * (std::cos (a)-T (1 ))
660- + dyfac * (std::cos (b)-T (1 ));
661-
662- // Tridiagonal solve
663- for (int k=0 ; k < nz; k++) {
664- if (k==0 ) {
665- ald[k] = T (0 .);
666- cud[k] = tric (i,j,k);
667- if (zlo_neumann) {
668- bd[k] = k2 - cud[k];
669- } else {
670- bd[k] = k2 - cud[k] - T (2.0 )*tria (i,j,k);
671- }
672- } else if (k == nz-1 ) {
673- ald[k] = tria (i,j,k);
674- cud[k] = T (0 .);
675- if (zhi_neumann) {
676- bd[k] = k2 - ald[k];
677- if (i == 0 && j == 0 && is_singular) {
678- bd[k] *= T (2.0 );
701+ amrex::LoopOnCpu (xybox, [&] (int i, int j, int )
702+ {
703+ T a = facx*(i+offset[0 ]);
704+ T b = facy*(j+offset[1 ]);
705+ T k2 = dxfac * (std::cos (a)-T (1 ))
706+ + dyfac * (std::cos (b)-T (1 ));
707+
708+ // Tridiagonal solve
709+ for (int k=0 ; k < nz; k++) {
710+ if (k==0 ) {
711+ ald[k] = T (0 .);
712+ cud[k] = tric (i,j,k);
713+ if (zlo_neumann) {
714+ bd[k] = k2 - cud[k];
715+ } else {
716+ bd[k] = k2 - cud[k] - T (2.0 )*tria (i,j,k);
717+ }
718+ } else if (k == nz-1 ) {
719+ ald[k] = tria (i,j,k);
720+ cud[k] = T (0 .);
721+ if (zhi_neumann) {
722+ bd[k] = k2 - ald[k];
723+ if (i == 0 && j == 0 && is_singular) {
724+ bd[k] *= T (2.0 );
725+ }
726+ } else {
727+ bd[k] = k2 - ald[k] - T (2.0 )*tric (i,j,k);
679728 }
680729 } else {
681- bd[k] = k2 - ald[k] - T (2.0 )*tric (i,j,k);
730+ ald[k] = tria (i,j,k);
731+ cud[k] = tric (i,j,k);
732+ bd[k] = k2 -ald[k]-cud[k];
682733 }
683- } else {
684- ald[k] = tria (i,j,k);
685- cud[k] = tric (i,j,k);
686- bd[k] = k2 -ald[k]-cud[k];
687734 }
688- }
689735
690- scratch[0 ] = cud[0 ]/bd[0 ];
691- spectral (i,j,0 ) = spectral (i,j,0 )/bd[0 ];
736+ scratch[0 ] = cud[0 ]/bd[0 ];
737+ spectral (i,j,0 ) = spectral (i,j,0 )/bd[0 ];
692738
693- for (int k = 1 ; k < nz; k++) {
694- if (k < nz-1 ) {
695- scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1 ]);
739+ for (int k = 1 ; k < nz; k++) {
740+ if (k < nz-1 ) {
741+ scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1 ]);
742+ }
743+ spectral (i,j,k) = (spectral (i,j,k) - ald[k] * spectral (i,j,k - 1 ))
744+ / (bd[k] - ald[k] * scratch[k-1 ]);
696745 }
697- spectral (i,j,k) = (spectral (i,j,k) - ald[k] * spectral (i,j,k - 1 ))
698- / (bd[k] - ald[k] * scratch[k-1 ]);
699- }
700746
701- for (int k = nz - 2 ; k >= 0 ; k--) {
702- spectral (i,j,k) -= scratch[k] * spectral (i,j,k + 1 );
703- }
747+ for (int k = nz - 2 ; k >= 0 ; k--) {
748+ spectral (i,j,k) -= scratch[k] * spectral (i,j,k + 1 );
749+ }
704750
705- for (int k = 0 ; k < nz; ++k) {
706- spectral (i,j,k) *= scale;
707- }
708- });
751+ for (int k = 0 ; k < nz; ++k) {
752+ spectral (i,j,k) *= scale;
753+ }
754+ });
709755#endif
756+ }
710757 }
711758#endif
712759}
0 commit comments