Skip to content

Commit d4ac5b3

Browse files
authored
FFT Poisson Solver: Neumann and Dirichlet Boundaries (#4202)
Add support for Neumann and Dirichlet boundaries in the FFT based Poisson solver. This requires cosine and sine transforms. For CPU builds, we use FFTW for these transforms. For GPU builds, we have implemented cosine and sine transforms using the real-to-complex transform provided by cuFFT, rocFFT and oneMKL.
1 parent 00e6f75 commit d4ac5b3

File tree

16 files changed

+2849
-1070
lines changed

16 files changed

+2849
-1070
lines changed

Docs/sphinx_documentation/source/FFT.rst

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ Below are examples of using :cpp:`FFT:R2C`.
4747
sp *= scaling;
4848
});
4949

50-
cMultiFab cmf(...);
50+
// Use R2C provided spectral data layout.
51+
auto const& [cba, cdm] = r2c.getSpectralDataLayout();
52+
cMultiFab cmf(cba, cdm, 1, 0);
5153
FFT::R2C<Real,FFT::Direction::forward> r2c_forward(geom.Domain());
5254
r2c_forward(mfin, cmf);
5355

@@ -56,16 +58,57 @@ Below are examples of using :cpp:`FFT:R2C`.
5658

5759
Note that using :cpp:`forwardThenBackward` is expected to be more efficient
5860
than separate calls to :cpp:`forward` and :cpp:`backward` because some
59-
parallel communication can be avoided. It should also be noted that a lot of
61+
parallel communication can be avoided. For the spectral data, the example
62+
above builds :cpp:`cMultiFab` using :cpp:`FFT::R2C` provided layout. You can
63+
also use your own :cpp:`BoxArray` and :cpp:`DistributionMapping`, but it
64+
might result in extra communication. It should also be noted that a lot of
6065
preparation works are done in the construction of an :cpp:`FFT::R2C`
61-
object. Therefore, one should cache it for reuse if possible.
66+
object. Therefore, one should cache it for reuse if possible. Although
67+
:cpp:`FFT::R2C` does not have a default constructor, one could always use
68+
:cpp:`std::unique_ptr<FFT::R2C<Real>>` to store an object in one's class.
6269

6370

6471
Poisson Solver
6572
==============
6673

67-
AMReX provides FFT based Poisson solvers. :cpp:`FFT::Poisson` supports all
68-
periodic boundaries using purely FFT. :cpp:`FFT::PoissonHybrid` is a 3D only
69-
solver that supports periodic boundaries in the first two dimensions and
70-
Neumann boundary in the last dimension. Similar to :cpp:`FFT::R2C`, the
71-
Poisson solvers should be cached for reuse.
74+
AMReX provides FFT based Poisson solvers. :cpp:`FFT::Poisson` supports
75+
periodic (:cpp:`FFT::Boundary::periodic`), homogeneous Neumann
76+
(:cpp:`FFT::Boundary::even`), and homogeneous Dirichlet
77+
(:cpp:`FFT::Boundary::odd`) boundaries using FFT. Below is an example of
78+
using the solver.
79+
80+
.. highlight:: c++
81+
82+
::
83+
84+
Geometry geom(...);
85+
MultiFab soln(...);
86+
MultiFab rhs(...);
87+
88+
Array<std::pair<FFT::Boundary,FFT::Boundary>,AMREX_SPACEDIM>
89+
fft_bc{...};
90+
91+
bool has_dirichlet = false;
92+
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
93+
has_dirichlet = has_dirichlet ||
94+
fft_bc[idim].first == FFT::Boundary::odd ||
95+
fft_bc[idim].second == FFT::Boundary::odd;
96+
}
97+
if (! has_dirichlet) {
98+
// Shift rhs so that its sum is zero.
99+
auto rhosum = rhs.sum(0);
100+
rhs.plus(-rhosum/geom.Domain().d_numPts(), 0, 1);
101+
}
102+
103+
FFT::Poisson fft_poisson(geom, fft_bc);
104+
fft_poisson.solve(soln, rhs);
105+
106+
:cpp:`FFT::PoissonHybrid` is a 3D only solver that supports periodic
107+
boundaries in the first two dimensions and Neumann boundary in the last
108+
dimension. The last dimension is solved with a tridiagonal solver that can
109+
support non-uniform cell size in the z-direction. For most applications,
110+
:cpp:`FFT::Poisson` should be used.
111+
112+
Similar to :cpp:`FFT::R2C`, the Poisson solvers should be cached for reuse,
113+
and one might need to use :cpp:`std::unique_ptr<FFT::Poisson<MultiFab>>`
114+
because there is no default constructor.

Src/Base/AMReX_GpuDevice.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@
3535
#include <roctracer/roctx.h>
3636
#endif
3737
#endif
38+
#if defined(AMREX_USE_FFT)
39+
# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
40+
# include <rocfft/rocfft.h>
41+
# else
42+
# include <rocfft.h>
43+
# endif
44+
#endif
3845
#endif
3946

4047
#ifdef AMREX_USE_ACC
@@ -310,6 +317,10 @@ Device::Initialize ()
310317
}
311318
#endif /* AMREX_USE_MPI */
312319

320+
#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT)
321+
AMREX_ROCFFT_SAFE_CALL(rocfft_setup());
322+
#endif
323+
313324
if (amrex::Verbose()) {
314325
#if defined(AMREX_USE_CUDA)
315326
amrex::Print() << "CUDA"
@@ -349,6 +360,10 @@ Device::Finalize ()
349360
#ifdef AMREX_USE_GPU
350361
Device::profilerStop();
351362

363+
#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT)
364+
AMREX_ROCFFT_SAFE_CALL(rocfft_cleanup());
365+
#endif
366+
352367
#ifdef AMREX_USE_SYCL
353368
for (auto& s : gpu_stream_pool) {
354369
delete s.queue;

0 commit comments

Comments
 (0)