Skip to content

Commit 59bd96a

Browse files
committed
FFT hybrid Poisson solver: Add support for batched 2d solves
This adds a new function `solve_2d` to the 3D hybrid Poisson solver. It solves an independent 2D (x-y) Poisson problem on each z-plane, i.e. a batch of 2D solves with no coupling in the z-direction.
1 parent 3446dfb commit 59bd96a

File tree

1 file changed

+156
-114
lines changed

1 file changed

+156
-114
lines changed

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 156 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ public:
193193
void solve (MF& soln, MF const& rhs, Vector<T> const& dz);
194194
void solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz);
195195

196+
void solve_2d (MF& a_soln, MF const& a_rhs);
197+
196198
template <typename TRIA, typename TRIC>
197199
void solve (MF& a_soln, MF const& a_rhs, TRIA const& tria, TRIC const& tric);
198200

@@ -350,6 +352,14 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
350352
#endif /* AMREX_SPACEDIM == 3 */
351353

352354
namespace fft_poisson_detail {
355+
// Tridiagonal-coefficient functor that evaluates to zero at every index.
// Passing this type as both TRIA and TRIC is how the batched 2D solve is
// requested: solve_z tests for this exact type pair with `if constexpr`
// and skips the tridiagonal solve in z.
template <typename T>
struct Tri_Zero {
    // The call operator is a non-static const member on purpose: a static
    // operator() is only permitted starting with C++23, so declaring it
    // static is rejected (or diagnosed) when building as C++17/20.
    [[nodiscard]] constexpr T operator() (int, int, int) const
    {
        return T(0);
    }
};
362+
353363
template <typename T>
354364
struct Tri_Uniform {
355365
[[nodiscard]] AMREX_GPU_DEVICE AMREX_FORCE_INLINE
@@ -480,6 +490,12 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
480490
fft_poisson_detail::TriC<T>{pdz,int(dz.size())});
481491
}
482492

493+
template <typename MF>
494+
void PoissonHybrid<MF>::solve_2d (MF& soln, MF const& rhs)
495+
{
496+
solve(soln, rhs, fft_poisson_detail::Tri_Zero<T>{}, fft_poisson_detail::Tri_Zero<T>{});
497+
}
498+
483499
template <typename MF>
484500
template <typename TRIA, typename TRIC>
485501
void PoissonHybrid<MF>::solve (MF& a_soln, MF const& a_rhs, TRIA const& tria,
@@ -565,150 +581,176 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
565581
}
566582
}
567583

568-
bool zlo_neumann = m_bc[2].first == Boundary::even;
569-
bool zhi_neumann = m_bc[2].second == Boundary::even;
570-
bool is_singular = (offset[0] == T(0)) && (offset[1] == T(0))
571-
&& zlo_neumann && zhi_neumann;
584+
if constexpr (std::is_same_v<TRIA,fft_poisson_detail::Tri_Zero<T>> &&
585+
std::is_same_v<TRIC,fft_poisson_detail::Tri_Zero<T>>) {
586+
#if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
587+
#pragma omp parallel
588+
#endif
589+
for (MFIter mfi(spmf,TilingIfNotGPU()); mfi.isValid(); ++mfi)
590+
{
591+
auto const& spectral = spmf.array(mfi);
592+
auto const& box = mfi.validbox();
593+
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
594+
{
595+
T a = facx*(i+offset[0]);
596+
T b = facy*(j+offset[1]);
597+
T k2 = dxfac * (std::cos(a)-T(1))
598+
+ dyfac * (std::cos(b)-T(1));
599+
if (k2 != T(0)) {
600+
spectral(i,j,k) /= k2;
601+
}
602+
spectral(i,j,k) *= scale;
603+
});
604+
}
605+
} else {
606+
bool zlo_neumann = m_bc[2].first == Boundary::even;
607+
bool zhi_neumann = m_bc[2].second == Boundary::even;
608+
bool is_singular = (offset[0] == T(0)) && (offset[1] == T(0))
609+
&& zlo_neumann && zhi_neumann;
572610

573-
auto nz = m_geom.Domain().length(2);
611+
auto nz = m_geom.Domain().length(2);
574612

575-
for (MFIter mfi(spmf); mfi.isValid(); ++mfi)
576-
{
577-
auto const& spectral = spmf.array(mfi);
578-
auto const& box = mfi.validbox();
579-
auto const& xybox = amrex::makeSlab(box, 2, 0);
613+
#if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
614+
#pragma omp parallel
615+
#endif
616+
for (MFIter mfi(spmf); mfi.isValid(); ++mfi)
617+
{
618+
auto const& spectral = spmf.array(mfi);
619+
auto const& box = mfi.validbox();
620+
auto const& xybox = amrex::makeSlab(box, 2, 0);
580621

581622
#ifdef AMREX_USE_GPU
582-
// xxxxx TODO: We need to explore how to optimize this
583-
// function. Maybe we can use cusparse. Maybe we should make
584-
// z-direction to be the unit stride direction.
623+
// xxxxx TODO: We need to explore how to optimize this
624+
// function. Maybe we can use cusparse. Maybe we should make
625+
// z-direction to be the unit stride direction.
585626

586-
FArrayBox tridiag_workspace(box,4);
587-
auto const& ald = tridiag_workspace.array(0);
588-
auto const& bd = tridiag_workspace.array(1);
589-
auto const& cud = tridiag_workspace.array(2);
590-
auto const& scratch = tridiag_workspace.array(3);
627+
FArrayBox tridiag_workspace(box,4);
628+
auto const& ald = tridiag_workspace.array(0);
629+
auto const& bd = tridiag_workspace.array(1);
630+
auto const& cud = tridiag_workspace.array(2);
631+
auto const& scratch = tridiag_workspace.array(3);
591632

592-
amrex::ParallelFor(xybox, [=] AMREX_GPU_DEVICE (int i, int j, int)
593-
{
594-
T a = facx*(i+offset[0]);
595-
T b = facy*(j+offset[1]);
596-
T k2 = dxfac * (std::cos(a)-T(1))
597-
+ dyfac * (std::cos(b)-T(1));
598-
599-
// Tridiagonal solve
600-
for(int k=0; k < nz; k++) {
601-
if(k==0) {
602-
ald(i,j,k) = T(0.);
603-
cud(i,j,k) = tric(i,j,k);
604-
if (zlo_neumann) {
605-
bd(i,j,k) = k2 - cud(i,j,k);
606-
} else {
607-
bd(i,j,k) = k2 - cud(i,j,k) - T(2.0)*tria(i,j,k);
608-
}
609-
} else if (k == nz-1) {
610-
ald(i,j,k) = tria(i,j,k);
611-
cud(i,j,k) = T(0.);
612-
if (zhi_neumann) {
613-
bd(i,j,k) = k2 - ald(i,j,k);
614-
if (i == 0 && j == 0 && is_singular) {
615-
bd(i,j,k) *= T(2.0);
633+
amrex::ParallelFor(xybox, [=] AMREX_GPU_DEVICE (int i, int j, int)
634+
{
635+
T a = facx*(i+offset[0]);
636+
T b = facy*(j+offset[1]);
637+
T k2 = dxfac * (std::cos(a)-T(1))
638+
+ dyfac * (std::cos(b)-T(1));
639+
640+
// Tridiagonal solve
641+
for(int k=0; k < nz; k++) {
642+
if(k==0) {
643+
ald(i,j,k) = T(0.);
644+
cud(i,j,k) = tric(i,j,k);
645+
if (zlo_neumann) {
646+
bd(i,j,k) = k2 - cud(i,j,k);
647+
} else {
648+
bd(i,j,k) = k2 - cud(i,j,k) - T(2.0)*tria(i,j,k);
649+
}
650+
} else if (k == nz-1) {
651+
ald(i,j,k) = tria(i,j,k);
652+
cud(i,j,k) = T(0.);
653+
if (zhi_neumann) {
654+
bd(i,j,k) = k2 - ald(i,j,k);
655+
if (i == 0 && j == 0 && is_singular) {
656+
bd(i,j,k) *= T(2.0);
657+
}
658+
} else {
659+
bd(i,j,k) = k2 - ald(i,j,k) - T(2.0)*tric(i,j,k);
616660
}
617661
} else {
618-
bd(i,j,k) = k2 - ald(i,j,k) - T(2.0)*tric(i,j,k);
662+
ald(i,j,k) = tria(i,j,k);
663+
cud(i,j,k) = tric(i,j,k);
664+
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
619665
}
620-
} else {
621-
ald(i,j,k) = tria(i,j,k);
622-
cud(i,j,k) = tric(i,j,k);
623-
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
624666
}
625-
}
626667

627-
scratch(i,j,0) = cud(i,j,0)/bd(i,j,0);
628-
spectral(i,j,0) = spectral(i,j,0)/bd(i,j,0);
668+
scratch(i,j,0) = cud(i,j,0)/bd(i,j,0);
669+
spectral(i,j,0) = spectral(i,j,0)/bd(i,j,0);
629670

630-
for (int k = 1; k < nz; k++) {
631-
if (k < nz-1) {
632-
scratch(i,j,k) = cud(i,j,k) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
671+
for (int k = 1; k < nz; k++) {
672+
if (k < nz-1) {
673+
scratch(i,j,k) = cud(i,j,k) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
674+
}
675+
spectral(i,j,k) = (spectral(i,j,k) - ald(i,j,k) * spectral(i,j,k - 1))
676+
/ (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
633677
}
634-
spectral(i,j,k) = (spectral(i,j,k) - ald(i,j,k) * spectral(i,j,k - 1))
635-
/ (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
636-
}
637678

638-
for (int k = nz - 2; k >= 0; k--) {
639-
spectral(i,j,k) -= scratch(i,j,k) * spectral(i,j,k + 1);
640-
}
679+
for (int k = nz - 2; k >= 0; k--) {
680+
spectral(i,j,k) -= scratch(i,j,k) * spectral(i,j,k + 1);
681+
}
641682

642-
for (int k = 0; k < nz; ++k) {
643-
spectral(i,j,k) *= scale;
644-
}
645-
});
646-
Gpu::streamSynchronize();
683+
for (int k = 0; k < nz; ++k) {
684+
spectral(i,j,k) *= scale;
685+
}
686+
});
687+
Gpu::streamSynchronize();
647688

648689
#else
649690

650-
Gpu::DeviceVector<T> ald(nz);
651-
Gpu::DeviceVector<T> bd(nz);
652-
Gpu::DeviceVector<T> cud(nz);
653-
Gpu::DeviceVector<T> scratch(nz);
691+
Gpu::DeviceVector<T> ald(nz);
692+
Gpu::DeviceVector<T> bd(nz);
693+
Gpu::DeviceVector<T> cud(nz);
694+
Gpu::DeviceVector<T> scratch(nz);
654695

655-
amrex::LoopOnCpu(xybox, [&] (int i, int j, int)
656-
{
657-
T a = facx*(i+offset[0]);
658-
T b = facy*(j+offset[1]);
659-
T k2 = dxfac * (std::cos(a)-T(1))
660-
+ dyfac * (std::cos(b)-T(1));
661-
662-
// Tridiagonal solve
663-
for(int k=0; k < nz; k++) {
664-
if(k==0) {
665-
ald[k] = T(0.);
666-
cud[k] = tric(i,j,k);
667-
if (zlo_neumann) {
668-
bd[k] = k2 - cud[k];
669-
} else {
670-
bd[k] = k2 - cud[k] - T(2.0)*tria(i,j,k);
671-
}
672-
} else if (k == nz-1) {
673-
ald[k] = tria(i,j,k);
674-
cud[k] = T(0.);
675-
if (zhi_neumann) {
676-
bd[k] = k2 - ald[k];
677-
if (i == 0 && j == 0 && is_singular) {
678-
bd[k] *= T(2.0);
696+
amrex::LoopOnCpu(xybox, [&] (int i, int j, int)
697+
{
698+
T a = facx*(i+offset[0]);
699+
T b = facy*(j+offset[1]);
700+
T k2 = dxfac * (std::cos(a)-T(1))
701+
+ dyfac * (std::cos(b)-T(1));
702+
703+
// Tridiagonal solve
704+
for(int k=0; k < nz; k++) {
705+
if(k==0) {
706+
ald[k] = T(0.);
707+
cud[k] = tric(i,j,k);
708+
if (zlo_neumann) {
709+
bd[k] = k2 - cud[k];
710+
} else {
711+
bd[k] = k2 - cud[k] - T(2.0)*tria(i,j,k);
712+
}
713+
} else if (k == nz-1) {
714+
ald[k] = tria(i,j,k);
715+
cud[k] = T(0.);
716+
if (zhi_neumann) {
717+
bd[k] = k2 - ald[k];
718+
if (i == 0 && j == 0 && is_singular) {
719+
bd[k] *= T(2.0);
720+
}
721+
} else {
722+
bd[k] = k2 - ald[k] - T(2.0)*tric(i,j,k);
679723
}
680724
} else {
681-
bd[k] = k2 - ald[k] - T(2.0)*tric(i,j,k);
725+
ald[k] = tria(i,j,k);
726+
cud[k] = tric(i,j,k);
727+
bd[k] = k2 -ald[k]-cud[k];
682728
}
683-
} else {
684-
ald[k] = tria(i,j,k);
685-
cud[k] = tric(i,j,k);
686-
bd[k] = k2 -ald[k]-cud[k];
687729
}
688-
}
689730

690-
scratch[0] = cud[0]/bd[0];
691-
spectral(i,j,0) = spectral(i,j,0)/bd[0];
731+
scratch[0] = cud[0]/bd[0];
732+
spectral(i,j,0) = spectral(i,j,0)/bd[0];
692733

693-
for (int k = 1; k < nz; k++) {
694-
if (k < nz-1) {
695-
scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1]);
734+
for (int k = 1; k < nz; k++) {
735+
if (k < nz-1) {
736+
scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1]);
737+
}
738+
spectral(i,j,k) = (spectral(i,j,k) - ald[k] * spectral(i,j,k - 1))
739+
/ (bd[k] - ald[k] * scratch[k-1]);
696740
}
697-
spectral(i,j,k) = (spectral(i,j,k) - ald[k] * spectral(i,j,k - 1))
698-
/ (bd[k] - ald[k] * scratch[k-1]);
699-
}
700741

701-
for (int k = nz - 2; k >= 0; k--) {
702-
spectral(i,j,k) -= scratch[k] * spectral(i,j,k + 1);
703-
}
742+
for (int k = nz - 2; k >= 0; k--) {
743+
spectral(i,j,k) -= scratch[k] * spectral(i,j,k + 1);
744+
}
704745

705-
for (int k = 0; k < nz; ++k) {
706-
spectral(i,j,k) *= scale;
707-
}
708-
});
746+
for (int k = 0; k < nz; ++k) {
747+
spectral(i,j,k) *= scale;
748+
}
749+
});
709750
#endif
710-
}
751+
}
711752
#endif
753+
}
712754
}
713755

714756
namespace detail {

0 commit comments

Comments
 (0)