Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ object. Therefore, one should cache it for reuse if possible. Although
Poisson Solver
==============

AMReX provides FFT based Poisson solvers. :cpp:`FFT::Poisson` supports
periodic (:cpp:`FFT::Boundary::periodic`), homogeneous Neumann
(:cpp:`FFT::Boundary::even`), and homogeneous Dirichlet
AMReX provides FFT based Poisson solvers. Here, Poisson's equation is

.. math::

\nabla^2 \phi = \rho.

:cpp:`FFT::Poisson` supports periodic (:cpp:`FFT::Boundary::periodic`),
homogeneous Neumann (:cpp:`FFT::Boundary::even`), and homogeneous Dirichlet
(:cpp:`FFT::Boundary::odd`) boundaries using FFT. Below is an example of
using the solver.

Expand Down Expand Up @@ -103,6 +108,27 @@ using the solver.
FFT::Poisson fft_poisson(geom, fft_bc);
fft_poisson.solve(soln, rhs);

:cpp:`FFT::PoissonOpenBC` is a 3D-only solver that supports open
boundaries. Its implementation utilizes :cpp:`FFT::OpenBCSolver`, which can
be used for implementing convolution-based solvers with a user-provided
Green's function. Users who want to extend the open BC solver to 2D or to
other Green's functions can use :cpp:`FFT::PoissonOpenBC` as an
example. Below is an example of solving Poisson's equation with open
boundaries.

.. highlight:: c++

::

Geometry geom(...);
MultiFab soln(...); // soln can be either nodal or cell-centered.
MultiFab rhs(...); // rhs must have the same index type as soln.

int ng = ...; // ng can be non-zero, if we want to compute potential
// outside the domain.
FFT::PoissonOpenBC openbc_solver(geom, soln.ixType(), IntVect(ng));
openbc_solver.solve(soln, rhs);

:cpp:`FFT::PoissonHybrid` is a 3D only solver that supports periodic
boundaries in the first two dimensions and Neumann boundary in the last
dimension. The last dimension is solved with a tridiagonal solver that can
Expand Down
3 changes: 2 additions & 1 deletion Src/Base/AMReX_BoxArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ namespace amrex
*/
[[nodiscard]] BoxArray decompose (Box const& domain, int nboxes,
Array<bool,AMREX_SPACEDIM> const& decomp
= {AMREX_D_DECL(true,true,true)});
= {AMREX_D_DECL(true,true,true)},
bool no_overlap = false);

struct BARef
{
Expand Down
21 changes: 18 additions & 3 deletions Src/Base/AMReX_BoxArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,7 @@ bool match (const BoxArray& x, const BoxArray& y)
}

BoxArray decompose (Box const& domain, int nboxes,
Array<bool,AMREX_SPACEDIM> const& decomp)
Array<bool,AMREX_SPACEDIM> const& decomp, bool no_overlap)
{
auto ndecomp = std::count(decomp.begin(), decomp.end(), true);

Expand Down Expand Up @@ -2048,9 +2048,24 @@ BoxArray decompose (Box const& domain, int nboxes,
ilo += domlo[0];
ihi += domlo[0];
Box b{IntVect(AMREX_D_DECL(ilo,jlo,klo)),
IntVect(AMREX_D_DECL(ihi,jhi,khi))};
IntVect(AMREX_D_DECL(ihi,jhi,khi)), ixtyp};
if (b.ok()) {
bl.push_back(b.convert(ixtyp));
if (no_overlap) {
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
if (ixtyp.nodeCentered(idim) &&
b.bigEnd(idim) == ccdomain.bigEnd(idim))
{
b.growHi(idim, 1);
}
}
} else {
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
if (ixtyp.nodeCentered(idim)) {
b.growHi(idim, 1);
}
}
}
bl.push_back(b);
}
AMREX_D_TERM(},},})

Expand Down
1 change: 1 addition & 0 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define AMREX_FFT_H_
#include <AMReX_Config.H>

#include <AMReX_FFT_OpenBCSolver.H>
#include <AMReX_FFT_R2C.H>
#include <AMReX_FFT_R2X.H>

Expand Down
142 changes: 142 additions & 0 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#ifndef AMREX_FFT_OPENBC_SOLVER_H_
#define AMREX_FFT_OPENBC_SOLVER_H_

#include <AMReX_FFT_R2C.H>

#include <AMReX_VisMF.H>

namespace amrex::FFT
{

/**
 * \brief FFT solver that multiplies the spectrum of a right-hand side by the
 * spectrum of a user-provided Green's function.
 *
 * The internal R2C transform is built on a box that keeps the low corner of
 * the user's domain and extends the high corner by the domain length in
 * every direction. setGreensFunction() stores the transformed Green's
 * function; solve() uses that stored spectrum.
 *
 * \tparam T floating point type of the data.
 */
template <typename T = Real>
class OpenBCSolver
{
public:
//! Real-space and spectral-space FabArray types, taken from R2C<T>.
using MF = typename R2C<T>::MF;
using cMF = typename R2C<T>::cMF;

/**
 * \brief Construct a solver for the given index-space domain.
 *
 * \param domain box on which the solution is wanted.
 */
explicit OpenBCSolver (Box const& domain);

/**
 * \brief Sample, transform and store the Green's function.
 *
 * \param greens_function device callable invoked as
 *        greens_function(i,j,k) with int indices; its result is stored
 *        as T.
 */
template <class F>
void setGreensFunction (F const& greens_function);

/**
 * \brief Compute phi from rho using the stored Green's function spectrum.
 *
 * \param phi destination; filled together with its ghost cells
 *        (phi.nGrowVect()).
 * \param rho source; component 0 is used.
 */
void solve (MF& phi, MF const& rho);

//! The domain passed to the constructor.
[[nodiscard]] Box const& Domain () const { return m_domain; }

private:
Box m_domain;   //!< domain given by the user
R2C<T> m_r2c;   //!< transform on the extended box
cMF m_G_fft;    //!< copy of the spectral data of the Green's function
};

/**
 * \brief Build the transform on the extended box and allocate the storage
 * that will hold the Green's function spectrum.
 *
 * \param domain box on which the solution is wanted.
 */
template <typename T>
OpenBCSolver<T>::OpenBCSolver (Box const& domain)
    : m_domain(domain),
      // Same low corner and index type as domain; the high corner is pushed
      // out by domain.length() in every direction.
      m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType()))
{
    auto [spectral, order] = m_r2c.getSpectralData();
    amrex::ignore_unused(order);

    // m_G_fft mirrors the layout of the transform's spectral data.
    auto const& spectral_ba = spectral->boxArray();
    auto const& spectral_dm = spectral->DistributionMap();
    m_G_fft.define(spectral_ba, spectral_dm, 1, 0);
}

/**
 * \brief Fill the transform's real buffer with the Green's function, run
 * the forward transform, and keep a copy of the resulting spectral data in
 * m_G_fft.
 *
 * \param greens_function device callable invoked as
 *        greens_function(i,j,k) with int indices.
 */
template <typename T>
template <class F>
void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
{
// Local piece of the real buffer; the null check below skips the fill when
// there is none.
auto* infab = detail::get_fab(m_r2c.m_rx);
auto const& lo = m_domain.smallEnd();
auto const& lo3 = lo.dim3();
auto const& len = m_domain.length3d();
if (infab) {
auto const& a = infab->array();
auto box = infab->box();
// nimages[d] == 2 means the local box spans the whole extended range in
// direction d, so the kernel below writes both a point and its mirror
// image in that direction.
GpuArray<int,3> nimages{1,1,1};
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
nimages[idim] = 2;
}
}
// The kernel never folds i (ii = i below), so direction 0 must have been
// shrunk above. NOTE(review): presumably R2C keeps direction 0 undivided;
// confirm against R2C.
AMREX_ASSERT(nimages[0] == 2);
// Work with indices relative to the low corner of the domain.
box.shift(-lo);
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
{
// The middle plane (index == len) is set to zero.
if (i == len[0] || j == len[1] || k == len[2]) {
a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0);
} else {
// Indices beyond the middle plane are reflected back (2*len - index)
// before the Green's function is evaluated.
auto ii = i;
auto jj = (j > len[1]) ? 2*len[1]-j : j;
auto kk = (k > len[2]) ? 2*len[2]-k : k;
auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
// Store G at this point and at its mirror images. A mirror index equal
// to 2*len (the image of index 0) is skipped.
for (int koff = 0; koff < nimages[2]; ++koff) {
int k2 = (koff == 0) ? k : 2*len[2]-k;
if (k2 == 2*len[2]) { continue; }
for (int joff = 0; joff < nimages[1]; ++joff) {
int j2 = (joff == 0) ? j : 2*len[1]-j;
if (j2 == 2*len[1]) { continue; }
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
int i2 = (ioff == 0) ? i : 2*len[0]-i;
if (i2 == 2*len[0]) { continue; }
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
}
}
}
}
});
}

// Transform the Green's function.
m_r2c.forward(m_r2c.m_rx);

// Copy the spectral data into m_G_fft.
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto const* srcfab = detail::get_fab(*sd);
if (srcfab) {
auto* dstfab = detail::get_fab(m_G_fft);
if (dstfab) {
// The copy routine is selected at compile time; both take
// (dst, src, number of bytes).
#if defined(AMREX_USE_GPU)
Gpu::dtod_memcpy_async
#else
std::memcpy
#endif
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
} else {
// m_G_fft was defined with the spectral data's layout, so a local source
// without a local destination is unexpected.
amrex::Abort("FFT::OpenBCSolver: how did this happen");
}
}
}

/**
 * \brief Compute phi from rho using the stored Green's function spectrum.
 *
 * rho is copied into the zeroed real buffer of the transform and
 * transformed forward. The spectral data is then multiplied pointwise by
 * the stored Green's function spectrum and by 1/(number of points of the
 * transform's real domain), and the result is transformed back into phi.
 *
 * \param phi destination; filled together with its ghost cells.
 * \param rho source; component 0 is used.
 */
template <typename T>
void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
{
    // Whatever rho does not cover in the buffer stays zero.
    auto& padded_rho = m_r2c.m_rx;
    padded_rho.setVal(T(0));
    padded_rho.ParallelCopy(rho, 0, 0, 1);

    m_r2c.forward(padded_rho);

    auto const norm = T(1) / T(m_r2c.m_real_domain.numPts());

    auto const* green = detail::get_fab(m_G_fft);
    if (green) {
        auto [spectral, order] = m_r2c.getSpectralData();
        amrex::ignore_unused(order);
        auto* rho_hat = detail::get_fab(*spectral);
        if (rho_hat == nullptr) {
            amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?");
        } else {
            auto* out = rho_hat->dataPtr();
            auto const* g = green->dataPtr();
            amrex::ParallelFor(rho_hat->box().numPts(), [=] AMREX_GPU_DEVICE (Long idx)
            {
                out[idx] *= g[idx] * norm;
            });
        }
    }

    m_r2c.backward_doit(phi, phi.nGrowVect());
}

}

#endif
81 changes: 80 additions & 1 deletion Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace amrex::FFT
{

/**
* \brief Poisson solver for all periodic boundaries using FFT
* \brief Poisson solver for periodic, Dirichlet & Neumann boundaries using
* FFT.
*/
template <typename MF = MultiFab>
class Poisson
Expand Down Expand Up @@ -40,6 +41,32 @@ private:
R2X<typename MF::value_type> m_r2x;
};

#if (AMREX_SPACEDIM == 3)
/**
 * \brief Poisson solver for open boundary conditions using FFT.
 *
 * Built on OpenBCSolver, with the Green's function set up by define_doit().
 */
template <typename MF = MultiFab>
class PoissonOpenBC
{
public:

/**
 * \brief Construct the solver.
 *
 * \param geom   geometry providing the domain and the cell sizes.
 * \param ixtype index type the domain is converted to; solve() asserts
 *               that soln and rhs have this index type.
 * \param ngrow  amount by which the converted domain is grown.
 */
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
explicit PoissonOpenBC (Geometry const& geom,
IndexType ixtype = IndexType::TheCellType(),
IntVect const& ngrow = IntVect(0));

/**
 * \brief Solve for soln given rhs.
 *
 * \param soln destination.
 * \param rhs  right-hand side.
 */
void solve (MF& soln, MF const& rhs);

//! Set the Green's function of the underlying solver; called by the constructor.
void define_doit (); // has to be public for cuda

private:
Geometry m_geom;       //!< copy of the geometry passed to the constructor
Box m_grown_domain;    //!< domain converted to ixtype and grown by ngrow
IntVect m_ngrow;       //!< copy of ngrow
OpenBCSolver<typename MF::value_type> m_solver; //!< solver on m_grown_domain
};
#endif

/**
* \brief 3D Poisson solver for periodic boundaries in the first two
* dimensions and Neumann in the last dimension.
Expand Down Expand Up @@ -123,6 +150,58 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
});
}

#if (AMREX_SPACEDIM == 3)

/**
 * \brief Store the geometry, build the grown domain and its solver, then
 * set up the Green's function.
 *
 * \param geom   geometry providing the domain and the cell sizes.
 * \param ixtype index type the domain is converted to.
 * \param ngrow  amount by which the converted domain is grown.
 */
template <typename MF>
template <typename FA, std::enable_if_t<IsFabArray_v<FA>,int> FOO>
PoissonOpenBC<MF>::PoissonOpenBC (Geometry const& geom, IndexType ixtype,
IntVect const& ngrow)
: m_geom(geom),
m_grown_domain(amrex::grow(amrex::convert(geom.Domain(),ixtype),ngrow)),
m_ngrow(ngrow),
// m_solver is built from m_grown_domain, which is initialized above it.
m_solver(m_grown_domain)
{
define_doit();
}

/**
 * \brief Hand the Green's function to the underlying OpenBCSolver.
 *
 * For index (i,j,k) the function sums 1/r over the eight combinations of
 * two quadrature points per direction and scales the sum by
 * -0.125 * dx*dy*dz / (4*pi).
 */
template <typename MF>
void PoissonOpenBC<MF>::define_doit ()
{
    using T = typename MF::value_type;

    auto const& lo = m_grown_domain.smallEnd();
    auto const dx = T(m_geom.CellSize(0));
    auto const dy = T(m_geom.CellSize(1));
    auto const dz = T(m_geom.CellSize(2));

    // Offset of a quadrature point from the integer index offset, in units
    // of the cell size.
    auto const gfac = T(1)/T(std::sqrt(T(12)));
    // The factor 0.125 averages over the 8 Gauss quadrature points.
    auto const fac = T(-0.125) * (dx*dy*dz) / (T(4)*Math::pi<T>());

    m_solver.setGreensFunction([=] AMREX_GPU_DEVICE (int i, int j, int k) -> T
    {
        // First Gauss quadrature point in each direction.
        auto const x = (T(i-lo[0]) - gfac) * dx;
        auto const y = (T(j-lo[1]) - gfac) * dy;
        auto const z = (T(k-lo[2]) - gfac) * dz;

        T r = 0;
        for (int gx = 0; gx < 2; ++gx) {
            auto const xg = x + 2*gx*gfac*dx;
            for (int gy = 0; gy < 2; ++gy) {
                auto const yg = y + 2*gy*gfac*dy;
                for (int gz = 0; gz < 2; ++gz) {
                    auto const zg = z + 2*gz*gfac*dz;
                    r += T(1)/std::sqrt(xg*xg+yg*yg+zg*zg);
                }
            }
        }
        return fac * r;
    });
}

/**
 * \brief Solve for soln given rhs by forwarding to the underlying
 * OpenBCSolver.
 *
 * \param soln destination; must have the index type given at construction.
 * \param rhs  right-hand side; must have the same index type.
 */
template <typename MF>
void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
{
// Both arguments must match the index type of the grown domain.
AMREX_ASSERT(m_grown_domain.ixType() == soln.ixType() && m_grown_domain.ixType() == rhs.ixType());
m_solver.solve(soln, rhs);
}

#endif /* AMREX_SPACEDIM == 3 */

template <typename MF>
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
{
Expand Down
Loading