Skip to content

Commit a3892de

Browse files
authored
FFT hybrid Poisson solver: Add support for batched 2d solves (#4732)
This adds a new function `solve_2d` to the 3d hybrid Poisson solver. It solves 2d Poisson problems at each z.
1 parent 3446dfb commit a3892de

File tree

1 file changed

+160
-113
lines changed

1 file changed

+160
-113
lines changed

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 160 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ public:
193193
void solve (MF& soln, MF const& rhs, Vector<T> const& dz);
194194
void solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz);
195195

196+
void solve_2d (MF& a_soln, MF const& a_rhs);
197+
196198
template <typename TRIA, typename TRIC>
197199
void solve (MF& a_soln, MF const& a_rhs, TRIA const& tria, TRIC const& tric);
198200

@@ -350,6 +352,14 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
350352
#endif /* AMREX_SPACEDIM == 3 */
351353

352354
namespace fft_poisson_detail {
355+
template <typename T>
356+
struct Tri_Zero {
357+
[[nodiscard]] constexpr T operator() (int, int, int) const
358+
{
359+
return 0;
360+
}
361+
};
362+
353363
template <typename T>
354364
struct Tri_Uniform {
355365
[[nodiscard]] AMREX_GPU_DEVICE AMREX_FORCE_INLINE
@@ -480,6 +490,12 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
480490
fft_poisson_detail::TriC<T>{pdz,int(dz.size())});
481491
}
482492

493+
template <typename MF>
494+
void PoissonHybrid<MF>::solve_2d (MF& soln, MF const& rhs)
495+
{
496+
solve(soln, rhs, fft_poisson_detail::Tri_Zero<T>{}, fft_poisson_detail::Tri_Zero<T>{});
497+
}
498+
483499
template <typename MF>
484500
template <typename TRIA, typename TRIC>
485501
void PoissonHybrid<MF>::solve (MF& a_soln, MF const& a_rhs, TRIA const& tria,
@@ -565,148 +581,179 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
565581
}
566582
}
567583

568-
bool zlo_neumann = m_bc[2].first == Boundary::even;
569-
bool zhi_neumann = m_bc[2].second == Boundary::even;
570-
bool is_singular = (offset[0] == T(0)) && (offset[1] == T(0))
571-
&& zlo_neumann && zhi_neumann;
584+
if
585+
#ifndef _WIN32
586+
constexpr
587+
#endif
588+
(std::is_same_v<TRIA,fft_poisson_detail::Tri_Zero<T>> &&
589+
std::is_same_v<TRIC,fft_poisson_detail::Tri_Zero<T>>) {
590+
amrex::ignore_unused(tria,tric);
591+
#if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
592+
#pragma omp parallel
593+
#endif
594+
for (MFIter mfi(spmf,TilingIfNotGPU()); mfi.isValid(); ++mfi)
595+
{
596+
auto const& spectral = spmf.array(mfi);
597+
auto const& box = mfi.validbox();
598+
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
599+
{
600+
T a = facx*(i+offset[0]);
601+
T b = facy*(j+offset[1]);
602+
T k2 = dxfac * (std::cos(a)-T(1))
603+
+ dyfac * (std::cos(b)-T(1));
604+
if (k2 != T(0)) {
605+
spectral(i,j,k) /= k2;
606+
}
607+
spectral(i,j,k) *= scale;
608+
});
609+
}
610+
} else {
611+
bool zlo_neumann = m_bc[2].first == Boundary::even;
612+
bool zhi_neumann = m_bc[2].second == Boundary::even;
613+
bool is_singular = (offset[0] == T(0)) && (offset[1] == T(0))
614+
&& zlo_neumann && zhi_neumann;
572615

573-
auto nz = m_geom.Domain().length(2);
616+
auto nz = m_geom.Domain().length(2);
574617

575-
for (MFIter mfi(spmf); mfi.isValid(); ++mfi)
576-
{
577-
auto const& spectral = spmf.array(mfi);
578-
auto const& box = mfi.validbox();
579-
auto const& xybox = amrex::makeSlab(box, 2, 0);
618+
#if defined(AMREX_USE_OMP) && !defined(AMREX_USE_GPU)
619+
#pragma omp parallel
620+
#endif
621+
for (MFIter mfi(spmf); mfi.isValid(); ++mfi)
622+
{
623+
auto const& spectral = spmf.array(mfi);
624+
auto const& box = mfi.validbox();
625+
auto const& xybox = amrex::makeSlab(box, 2, 0);
580626

581627
#ifdef AMREX_USE_GPU
582-
// xxxxx TODO: We need to explore how to optimize this
583-
// function. Maybe we can use cusparse. Maybe we should make
584-
// z-direction to be the unit stride direction.
628+
// xxxxx TODO: We need to explore how to optimize this
629+
// function. Maybe we can use cusparse. Maybe we should make
630+
// z-direction to be the unit stride direction.
585631

586-
FArrayBox tridiag_workspace(box,4);
587-
auto const& ald = tridiag_workspace.array(0);
588-
auto const& bd = tridiag_workspace.array(1);
589-
auto const& cud = tridiag_workspace.array(2);
590-
auto const& scratch = tridiag_workspace.array(3);
632+
FArrayBox tridiag_workspace(box,4);
633+
auto const& ald = tridiag_workspace.array(0);
634+
auto const& bd = tridiag_workspace.array(1);
635+
auto const& cud = tridiag_workspace.array(2);
636+
auto const& scratch = tridiag_workspace.array(3);
591637

592-
amrex::ParallelFor(xybox, [=] AMREX_GPU_DEVICE (int i, int j, int)
593-
{
594-
T a = facx*(i+offset[0]);
595-
T b = facy*(j+offset[1]);
596-
T k2 = dxfac * (std::cos(a)-T(1))
597-
+ dyfac * (std::cos(b)-T(1));
598-
599-
// Tridiagonal solve
600-
for(int k=0; k < nz; k++) {
601-
if(k==0) {
602-
ald(i,j,k) = T(0.);
603-
cud(i,j,k) = tric(i,j,k);
604-
if (zlo_neumann) {
605-
bd(i,j,k) = k2 - cud(i,j,k);
606-
} else {
607-
bd(i,j,k) = k2 - cud(i,j,k) - T(2.0)*tria(i,j,k);
608-
}
609-
} else if (k == nz-1) {
610-
ald(i,j,k) = tria(i,j,k);
611-
cud(i,j,k) = T(0.);
612-
if (zhi_neumann) {
613-
bd(i,j,k) = k2 - ald(i,j,k);
614-
if (i == 0 && j == 0 && is_singular) {
615-
bd(i,j,k) *= T(2.0);
638+
amrex::ParallelFor(xybox, [=] AMREX_GPU_DEVICE (int i, int j, int)
639+
{
640+
T a = facx*(i+offset[0]);
641+
T b = facy*(j+offset[1]);
642+
T k2 = dxfac * (std::cos(a)-T(1))
643+
+ dyfac * (std::cos(b)-T(1));
644+
645+
// Tridiagonal solve
646+
for(int k=0; k < nz; k++) {
647+
if(k==0) {
648+
ald(i,j,k) = T(0.);
649+
cud(i,j,k) = tric(i,j,k);
650+
if (zlo_neumann) {
651+
bd(i,j,k) = k2 - cud(i,j,k);
652+
} else {
653+
bd(i,j,k) = k2 - cud(i,j,k) - T(2.0)*tria(i,j,k);
654+
}
655+
} else if (k == nz-1) {
656+
ald(i,j,k) = tria(i,j,k);
657+
cud(i,j,k) = T(0.);
658+
if (zhi_neumann) {
659+
bd(i,j,k) = k2 - ald(i,j,k);
660+
if (i == 0 && j == 0 && is_singular) {
661+
bd(i,j,k) *= T(2.0);
662+
}
663+
} else {
664+
bd(i,j,k) = k2 - ald(i,j,k) - T(2.0)*tric(i,j,k);
616665
}
617666
} else {
618-
bd(i,j,k) = k2 - ald(i,j,k) - T(2.0)*tric(i,j,k);
667+
ald(i,j,k) = tria(i,j,k);
668+
cud(i,j,k) = tric(i,j,k);
669+
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
619670
}
620-
} else {
621-
ald(i,j,k) = tria(i,j,k);
622-
cud(i,j,k) = tric(i,j,k);
623-
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
624671
}
625-
}
626672

627-
scratch(i,j,0) = cud(i,j,0)/bd(i,j,0);
628-
spectral(i,j,0) = spectral(i,j,0)/bd(i,j,0);
673+
scratch(i,j,0) = cud(i,j,0)/bd(i,j,0);
674+
spectral(i,j,0) = spectral(i,j,0)/bd(i,j,0);
629675

630-
for (int k = 1; k < nz; k++) {
631-
if (k < nz-1) {
632-
scratch(i,j,k) = cud(i,j,k) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
676+
for (int k = 1; k < nz; k++) {
677+
if (k < nz-1) {
678+
scratch(i,j,k) = cud(i,j,k) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
679+
}
680+
spectral(i,j,k) = (spectral(i,j,k) - ald(i,j,k) * spectral(i,j,k - 1))
681+
/ (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
633682
}
634-
spectral(i,j,k) = (spectral(i,j,k) - ald(i,j,k) * spectral(i,j,k - 1))
635-
/ (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
636-
}
637683

638-
for (int k = nz - 2; k >= 0; k--) {
639-
spectral(i,j,k) -= scratch(i,j,k) * spectral(i,j,k + 1);
640-
}
684+
for (int k = nz - 2; k >= 0; k--) {
685+
spectral(i,j,k) -= scratch(i,j,k) * spectral(i,j,k + 1);
686+
}
641687

642-
for (int k = 0; k < nz; ++k) {
643-
spectral(i,j,k) *= scale;
644-
}
645-
});
646-
Gpu::streamSynchronize();
688+
for (int k = 0; k < nz; ++k) {
689+
spectral(i,j,k) *= scale;
690+
}
691+
});
692+
Gpu::streamSynchronize();
647693

648694
#else
649695

650-
Gpu::DeviceVector<T> ald(nz);
651-
Gpu::DeviceVector<T> bd(nz);
652-
Gpu::DeviceVector<T> cud(nz);
653-
Gpu::DeviceVector<T> scratch(nz);
696+
Gpu::DeviceVector<T> ald(nz);
697+
Gpu::DeviceVector<T> bd(nz);
698+
Gpu::DeviceVector<T> cud(nz);
699+
Gpu::DeviceVector<T> scratch(nz);
654700

655-
amrex::LoopOnCpu(xybox, [&] (int i, int j, int)
656-
{
657-
T a = facx*(i+offset[0]);
658-
T b = facy*(j+offset[1]);
659-
T k2 = dxfac * (std::cos(a)-T(1))
660-
+ dyfac * (std::cos(b)-T(1));
661-
662-
// Tridiagonal solve
663-
for(int k=0; k < nz; k++) {
664-
if(k==0) {
665-
ald[k] = T(0.);
666-
cud[k] = tric(i,j,k);
667-
if (zlo_neumann) {
668-
bd[k] = k2 - cud[k];
669-
} else {
670-
bd[k] = k2 - cud[k] - T(2.0)*tria(i,j,k);
671-
}
672-
} else if (k == nz-1) {
673-
ald[k] = tria(i,j,k);
674-
cud[k] = T(0.);
675-
if (zhi_neumann) {
676-
bd[k] = k2 - ald[k];
677-
if (i == 0 && j == 0 && is_singular) {
678-
bd[k] *= T(2.0);
701+
amrex::LoopOnCpu(xybox, [&] (int i, int j, int)
702+
{
703+
T a = facx*(i+offset[0]);
704+
T b = facy*(j+offset[1]);
705+
T k2 = dxfac * (std::cos(a)-T(1))
706+
+ dyfac * (std::cos(b)-T(1));
707+
708+
// Tridiagonal solve
709+
for(int k=0; k < nz; k++) {
710+
if(k==0) {
711+
ald[k] = T(0.);
712+
cud[k] = tric(i,j,k);
713+
if (zlo_neumann) {
714+
bd[k] = k2 - cud[k];
715+
} else {
716+
bd[k] = k2 - cud[k] - T(2.0)*tria(i,j,k);
717+
}
718+
} else if (k == nz-1) {
719+
ald[k] = tria(i,j,k);
720+
cud[k] = T(0.);
721+
if (zhi_neumann) {
722+
bd[k] = k2 - ald[k];
723+
if (i == 0 && j == 0 && is_singular) {
724+
bd[k] *= T(2.0);
725+
}
726+
} else {
727+
bd[k] = k2 - ald[k] - T(2.0)*tric(i,j,k);
679728
}
680729
} else {
681-
bd[k] = k2 - ald[k] - T(2.0)*tric(i,j,k);
730+
ald[k] = tria(i,j,k);
731+
cud[k] = tric(i,j,k);
732+
bd[k] = k2 -ald[k]-cud[k];
682733
}
683-
} else {
684-
ald[k] = tria(i,j,k);
685-
cud[k] = tric(i,j,k);
686-
bd[k] = k2 -ald[k]-cud[k];
687734
}
688-
}
689735

690-
scratch[0] = cud[0]/bd[0];
691-
spectral(i,j,0) = spectral(i,j,0)/bd[0];
736+
scratch[0] = cud[0]/bd[0];
737+
spectral(i,j,0) = spectral(i,j,0)/bd[0];
692738

693-
for (int k = 1; k < nz; k++) {
694-
if (k < nz-1) {
695-
scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1]);
739+
for (int k = 1; k < nz; k++) {
740+
if (k < nz-1) {
741+
scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1]);
742+
}
743+
spectral(i,j,k) = (spectral(i,j,k) - ald[k] * spectral(i,j,k - 1))
744+
/ (bd[k] - ald[k] * scratch[k-1]);
696745
}
697-
spectral(i,j,k) = (spectral(i,j,k) - ald[k] * spectral(i,j,k - 1))
698-
/ (bd[k] - ald[k] * scratch[k-1]);
699-
}
700746

701-
for (int k = nz - 2; k >= 0; k--) {
702-
spectral(i,j,k) -= scratch[k] * spectral(i,j,k + 1);
703-
}
747+
for (int k = nz - 2; k >= 0; k--) {
748+
spectral(i,j,k) -= scratch[k] * spectral(i,j,k + 1);
749+
}
704750

705-
for (int k = 0; k < nz; ++k) {
706-
spectral(i,j,k) *= scale;
707-
}
708-
});
751+
for (int k = 0; k < nz; ++k) {
752+
spectral(i,j,k) *= scale;
753+
}
754+
});
709755
#endif
756+
}
710757
}
711758
#endif
712759
}

0 commit comments

Comments
 (0)