Skip to content

Commit 5a260a7

Browse files
committed
FFT: OpenBC Solver
This implements the algorithm of Hockney, Methods Comp. Phys. 9 (1970), 136-210 for solving Poisson's equation with open boundaries.
1 parent 34d3eda commit 5a260a7

File tree

13 files changed

+464
-25
lines changed

13 files changed

+464
-25
lines changed

Docs/sphinx_documentation/source/FFT.rst

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,14 @@ object. Therefore, one should cache it for reuse if possible. Although
7171
Poisson Solver
7272
==============
7373

74-
AMReX provides FFT based Poisson solvers. :cpp:`FFT::Poisson` supports
75-
periodic (:cpp:`FFT::Boundary::periodic`), homogeneous Neumann
76-
(:cpp:`FFT::Boundary::even`), and homogeneous Dirichlet
74+
AMReX provides FFT based Poisson solvers. Here, Poisson's equation is
75+
76+
.. math::
77+
78+
\nabla^2 \phi = \rho.
79+
80+
:cpp:`FFT::Poisson` supports periodic (:cpp:`FFT::Boundary::periodic`),
81+
homogeneous Neumann (:cpp:`FFT::Boundary::even`), and homogeneous Dirichlet
7782
(:cpp:`FFT::Boundary::odd`) boundaries using FFT. Below is an example of
7883
using the solver.
7984

@@ -103,6 +108,27 @@ using the solver.
103108
FFT::Poisson fft_poisson(geom, fft_bc);
104109
fft_poisson.solve(soln, rhs);
105110

111+
:cpp:`FFT::PoissonOpenBC` is a 3D only solver that supports open
112+
boundaries. Its implementation utilizes :cpp:`FFT::OpenBCSolver`, which can
113+
be used for implementing convolution based solvers with a user provided
114+
Green's function. If users want to extend the open BC solver to 2D or other
115+
types of Green's function, they could use :cpp:`FFT::PoissonOpenBC` as an
116+
example. Below is an example of solving Poisson's equation with open
117+
boundaries.
118+
119+
.. highlight:: c++
120+
121+
::
122+
123+
Geometry geom(...);
124+
MultiFab soln(...); // soln can be either nodal or cell-centered.
125+
MultiFab rhs(...); // rhs must have the same index type as soln.
126+
127+
int ng = ...; // ng can be non-zero, if we want to compute potential
128+
// outside the domain.
129+
FFT::PoissonOpenBC openbc_solver(geom, soln.ixType(), IntVect(ng));
130+
openbc_solver.solve(soln, rhs);
131+
106132
:cpp:`FFT::PoissonHybrid` is a 3D only solver that supports periodic
107133
boundaries in the first two dimensions and Neumann boundary in the last
108134
dimension. The last dimension is solved with a tridiagonal solver that can

Src/Base/AMReX_BoxArray.H

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ namespace amrex
6969
*/
7070
[[nodiscard]] BoxArray decompose (Box const& domain, int nboxes,
7171
Array<bool,AMREX_SPACEDIM> const& decomp
72-
= {AMREX_D_DECL(true,true,true)});
72+
= {AMREX_D_DECL(true,true,true)},
73+
bool no_overlap = false);
7374

7475
struct BARef
7576
{

Src/Base/AMReX_BoxArray.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,7 +1891,7 @@ bool match (const BoxArray& x, const BoxArray& y)
18911891
}
18921892

18931893
BoxArray decompose (Box const& domain, int nboxes,
1894-
Array<bool,AMREX_SPACEDIM> const& decomp)
1894+
Array<bool,AMREX_SPACEDIM> const& decomp, bool no_overlap)
18951895
{
18961896
auto ndecomp = std::count(decomp.begin(), decomp.end(), true);
18971897

@@ -2048,9 +2048,24 @@ BoxArray decompose (Box const& domain, int nboxes,
20482048
ilo += domlo[0];
20492049
ihi += domlo[0];
20502050
Box b{IntVect(AMREX_D_DECL(ilo,jlo,klo)),
2051-
IntVect(AMREX_D_DECL(ihi,jhi,khi))};
2051+
IntVect(AMREX_D_DECL(ihi,jhi,khi)), ixtyp};
20522052
if (b.ok()) {
2053-
bl.push_back(b.convert(ixtyp));
2053+
if (no_overlap) {
2054+
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
2055+
if (ixtyp.nodeCentered(idim) &&
2056+
b.bigEnd(idim) == ccdomain.bigEnd(idim))
2057+
{
2058+
b.growHi(idim, 1);
2059+
}
2060+
}
2061+
} else {
2062+
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
2063+
if (ixtyp.nodeCentered(idim)) {
2064+
b.growHi(idim, 1);
2065+
}
2066+
}
2067+
}
2068+
bl.push_back(b);
20542069
}
20552070
AMREX_D_TERM(},},})
20562071

Src/FFT/AMReX_FFT.H

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define AMREX_FFT_H_
33
#include <AMReX_Config.H>
44

5+
#include <AMReX_FFT_OpenBCSolver.H>
56
#include <AMReX_FFT_R2C.H>
67
#include <AMReX_FFT_R2X.H>
78

Src/FFT/AMReX_FFT_OpenBCSolver.H

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#ifndef AMREX_FFT_OPENBC_SOLVER_H_
2+
#define AMREX_FFT_OPENBC_SOLVER_H_
3+
4+
#include <AMReX_FFT_R2C.H>
5+
6+
#include <AMReX_VisMF.H>
7+
8+
namespace amrex::FFT
9+
{
10+
11+
template <typename T = Real>
12+
class OpenBCSolver
13+
{
14+
public:
15+
using MF = typename R2C<T>::MF;
16+
using cMF = typename R2C<T>::cMF;
17+
18+
explicit OpenBCSolver (Box const& domain);
19+
20+
template <class F>
21+
void setGreensFunction (F const& greens_function);
22+
23+
void solve (MF& phi, MF const& rho);
24+
25+
[[nodiscard]] Box const& Domain () const { return m_domain; }
26+
27+
private:
28+
Box m_domain;
29+
R2C<T> m_r2c;
30+
cMF m_G_fft;
31+
};
32+
33+
template <typename T>
34+
OpenBCSolver<T>::OpenBCSolver (Box const& domain)
35+
: m_domain(domain),
36+
m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType()))
37+
{
38+
auto [sd, ord] = m_r2c.getSpectralData();
39+
amrex::ignore_unused(ord);
40+
m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
41+
}
42+
43+
template <typename T>
44+
template <class F>
45+
void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
46+
{
47+
auto* infab = detail::get_fab(m_r2c.m_rx);
48+
auto const& lo = m_domain.smallEnd();
49+
auto const& lo3 = lo.dim3();
50+
auto const& len = m_domain.length3d();
51+
if (infab) {
52+
auto const& a = infab->array();
53+
auto box = infab->box();
54+
GpuArray<int,3> nimages{1,1,1};
55+
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
56+
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
57+
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
58+
nimages[idim] = 2;
59+
}
60+
}
61+
AMREX_ASSERT(nimages[0] == 2);
62+
box.shift(-lo);
63+
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
64+
{
65+
if (i == len[0] || j == len[1] || k == len[2]) {
66+
a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0);
67+
} else {
68+
auto ii = i;
69+
auto jj = (j > len[1]) ? 2*len[1]-j : j;
70+
auto kk = (k > len[2]) ? 2*len[2]-k : k;
71+
auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
72+
for (int koff = 0; koff < nimages[2]; ++koff) {
73+
int k2 = (koff == 0) ? k : 2*len[2]-k;
74+
if (k2 == 2*len[2]) { continue; }
75+
for (int joff = 0; joff < nimages[1]; ++joff) {
76+
int j2 = (joff == 0) ? j : 2*len[1]-j;
77+
if (j2 == 2*len[1]) { continue; }
78+
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
79+
int i2 = (ioff == 0) ? i : 2*len[0]-i;
80+
if (i2 == 2*len[0]) { continue; }
81+
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
82+
}
83+
}
84+
}
85+
}
86+
});
87+
}
88+
89+
m_r2c.forward(m_r2c.m_rx);
90+
91+
auto [sd, ord] = m_r2c.getSpectralData();
92+
amrex::ignore_unused(ord);
93+
auto const* srcfab = detail::get_fab(*sd);
94+
if (srcfab) {
95+
auto* dstfab = detail::get_fab(m_G_fft);
96+
if (dstfab) {
97+
#if defined(AMREX_USE_GPU)
98+
Gpu::dtod_memcpy_async
99+
#else
100+
std::memcpy
101+
#endif
102+
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
103+
} else {
104+
amrex::Abort("FFT::OpenBCSolver: how did this happen");
105+
}
106+
}
107+
}
108+
109+
template <typename T>
110+
void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
111+
{
112+
auto& inmf = m_r2c.m_rx;
113+
inmf.setVal(T(0));
114+
inmf.ParallelCopy(rho, 0, 0, 1);
115+
116+
m_r2c.forward(inmf);
117+
118+
auto scaling_factor = T(1) / T(m_r2c.m_real_domain.numPts());
119+
120+
auto const* gfab = detail::get_fab(m_G_fft);
121+
if (gfab) {
122+
auto [sd, ord] = m_r2c.getSpectralData();
123+
amrex::ignore_unused(ord);
124+
auto* rhofab = detail::get_fab(*sd);
125+
if (rhofab) {
126+
auto* pdst = rhofab->dataPtr();
127+
auto const* psrc = gfab->dataPtr();
128+
amrex::ParallelFor(rhofab->box().numPts(), [=] AMREX_GPU_DEVICE (Long i)
129+
{
130+
pdst[i] *= psrc[i] * scaling_factor;
131+
});
132+
} else {
133+
amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?");
134+
}
135+
}
136+
137+
m_r2c.backward_doit(phi, phi.nGrowVect());
138+
}
139+
140+
}
141+
142+
#endif

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ namespace amrex::FFT
88
{
99

1010
/**
11-
* \brief Poisson solver for all periodic boundaries using FFT
11+
* \brief Poisson solver for periodic, Dirichlet & Neumann boundaries using
12+
* FFT.
1213
*/
1314
template <typename MF = MultiFab>
1415
class Poisson
@@ -40,6 +41,32 @@ private:
4041
R2X<typename MF::value_type> m_r2x;
4142
};
4243

44+
#if (AMREX_SPACEDIM == 3)
45+
/**
46+
* \brief Poisson solver for open boundary conditions using FFT.
47+
*/
48+
template <typename MF = MultiFab>
49+
class PoissonOpenBC
50+
{
51+
public:
52+
53+
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
54+
explicit PoissonOpenBC (Geometry const& geom,
55+
IndexType ixtype = IndexType::TheCellType(),
56+
IntVect const& ngrow = IntVect(0));
57+
58+
void solve (MF& soln, MF const& rhs);
59+
60+
void define_doit (); // has to be public for cuda
61+
62+
private:
63+
Geometry m_geom;
64+
Box m_grown_domain;
65+
IntVect m_ngrow;
66+
OpenBCSolver<typename MF::value_type> m_solver;
67+
};
68+
#endif
69+
4370
/**
4471
* \brief 3D Poisson solver for periodic boundaries in the first two
4572
* dimensions and Neumann in the last dimension.
@@ -123,6 +150,58 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
123150
});
124151
}
125152

153+
#if (AMREX_SPACEDIM == 3)
154+
155+
template <typename MF>
156+
template <typename FA, std::enable_if_t<IsFabArray_v<FA>,int> FOO>
157+
PoissonOpenBC<MF>::PoissonOpenBC (Geometry const& geom, IndexType ixtype,
158+
IntVect const& ngrow)
159+
: m_geom(geom),
160+
m_grown_domain(amrex::grow(amrex::convert(geom.Domain(),ixtype),ngrow)),
161+
m_ngrow(ngrow),
162+
m_solver(m_grown_domain)
163+
{
164+
define_doit();
165+
}
166+
167+
template <typename MF>
168+
void PoissonOpenBC<MF>::define_doit ()
169+
{
170+
using T = typename MF::value_type;
171+
auto const& lo = m_grown_domain.smallEnd();
172+
auto const dx = T(m_geom.CellSize(0));
173+
auto const dy = T(m_geom.CellSize(1));
174+
auto const dz = T(m_geom.CellSize(2));
175+
auto const gfac = T(1)/T(std::sqrt(T(12)));
176+
// 0.125 = 1/8: the result is averaged over the 8 Gauss quadrature points
177+
auto const fac = T(-0.125) * (dx*dy*dz) / (T(4)*Math::pi<T>());
178+
m_solver.setGreensFunction([=] AMREX_GPU_DEVICE (int i, int j, int k) -> T
179+
{
180+
auto x = (T(i-lo[0]) - gfac) * dx; // first Gauss quadrature point
181+
auto y = (T(j-lo[1]) - gfac) * dy;
182+
auto z = (T(k-lo[2]) - gfac) * dz;
183+
T r = 0;
184+
for (int gx = 0; gx < 2; ++gx) {
185+
for (int gy = 0; gy < 2; ++gy) {
186+
for (int gz = 0; gz < 2; ++gz) {
187+
auto xg = x + 2*gx*gfac*dx;
188+
auto yg = y + 2*gy*gfac*dy;
189+
auto zg = z + 2*gz*gfac*dz;
190+
r += T(1)/std::sqrt(xg*xg+yg*yg+zg*zg);
191+
}}}
192+
return fac * r;
193+
});
194+
}
195+
196+
template <typename MF>
197+
void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
198+
{
199+
AMREX_ASSERT(m_grown_domain.ixType() == soln.ixType() && m_grown_domain.ixType() == rhs.ixType());
200+
m_solver.solve(soln, rhs);
201+
}
202+
203+
#endif /* AMREX_SPACEDIM == 3 */
204+
126205
template <typename MF>
127206
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
128207
{

0 commit comments

Comments
 (0)