Skip to content

Commit e26cf17

Browse files
committed
Add FFT Poisson solvers
1 parent e40a9ec commit e26cf17

File tree

7 files changed

+367
-91
lines changed

7 files changed

+367
-91
lines changed

Src/FFT/AMReX_FFT.H

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,19 @@ public:
4747
template <typename F>
4848
void forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward)
4949
{
50-
this->forward_doit(inmf);
50+
this->forward(inmf);
5151
this->post_forward_doit(post_forward);
52-
this->backward_doit(outmf);
52+
this->backward(outmf);
5353
}
5454

55+
void forward (MF const& inmf, Scaling scaling = Scaling::none);
56+
void forward (MF const& inmf, cMF& outmf, Scaling scaling = Scaling::none);
57+
58+
void backward (MF& outmf, Scaling scaling = Scaling::none);
59+
void backward (cMF const& inmf, MF& outmf, Scaling scaling = Scaling::none);
60+
61+
std::pair<cMF*,IntVect> getSpectralData ();
62+
5563
struct Swap01
5664
{
5765
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept
@@ -153,9 +161,6 @@ private:
153161
}
154162
}
155163

156-
void forward_doit (MF const& inmf, Scaling scaling = Scaling::none);
157-
void backward_doit (MF& outmf, Scaling scaling = Scaling::none);
158-
159164
static void exec_r2c (Plan plan, MF& in, cMF& out);
160165
static void exec_c2r (Plan plan, cMF& in, MF& out);
161166
template <Direction direction>
@@ -175,10 +180,10 @@ private:
175180
// Comm meta-data. In the forward phase, we start with (x,y,z),
176181
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
177182
// perform inverse transpose.
178-
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y;
179-
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x;
180-
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z;
181-
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y;
183+
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y; // (x,y,z) -> (y,x,z)
184+
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x; // (y,x,z) -> (x,y,z)
185+
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z; // (y,x,z) -> (z,x,y)
186+
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y; // (z,x,y) -> (y,x,z)
182187
Swap01 m_dtos_x2y{};
183188
Swap01 m_dtos_y2x{};
184189
Swap02 m_dtos_y2z{};
@@ -232,12 +237,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
232237
int nprocs = ParallelDescriptor::NProcs();
233238

234239
auto bax = amrex::decompose(m_real_domain, nprocs, {AMREX_D_DECL(false,true,true)});
235-
DistributionMapping dmx;
236-
{
237-
Vector<int> pm(bax.size());
238-
std::iota(pm.begin(), pm.end(), 0);
239-
dmx.define(std::move(pm));
240-
}
240+
DistributionMapping dmx = detail::make_iota_distromap(bax.size());
241241
m_rx.define(bax, dmx, 1, 0);
242242

243243
{
@@ -346,9 +346,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
346346
if (cbay.size() == dmx.size()) {
347347
cdmy = dmx;
348348
} else {
349-
Vector<int> pm(cbay.size());
350-
std::iota(pm.begin(), pm.end(), 0);
351-
cdmy.define(std::move(pm));
349+
cdmy = detail::make_iota_distromap(cbay.size());
352350
}
353351
m_cy.define(cbay, cdmy, 1, 0);
354352

@@ -365,7 +363,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
365363

366364
#if (AMREX_SPACEDIM == 3)
367365
if (m_real_domain.length(1) > 1 &&
368-
(! m_info.batch_mode || m_real_domain.length(2) > 1))
366+
(! m_info.batch_mode && m_real_domain.length(2) > 1))
369367
{
370368
auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true});
371369
DistributionMapping cdmz;
@@ -374,9 +372,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
374372
} else if (cbaz.size() == cdmy.size()) {
375373
cdmz = cdmy;
376374
} else {
377-
Vector<int> pm(cbaz.size());
378-
std::iota(pm.begin(), pm.end(), 0);
379-
cdmz.define(std::move(pm));
375+
cdmz = detail::make_iota_distromap(cbaz.size());
380376
}
381377
m_cz.define(cbaz, cdmz, 1, 0);
382378

@@ -563,8 +559,10 @@ void R2C<T>::exec_c2c (Plan2 plan, cMF& inout)
563559
}
564560

565561
template <typename T>
566-
void R2C<T>::forward_doit (MF const& inmf, Scaling /*scaling*/)
562+
void R2C<T>::forward (MF const& inmf, Scaling scaling)
567563
{
564+
AMREX_ALWAYS_ASSERT(scaling == Scaling::none); // xxxxx TODO
565+
568566
m_rx.ParallelCopy(inmf, 0, 0, 1);
569567
exec_r2c(m_fft_fwd_x, m_rx, m_cx);
570568

@@ -580,8 +578,10 @@ void R2C<T>::forward_doit (MF const& inmf, Scaling /*scaling*/)
580578
}
581579

582580
template <typename T>
583-
void R2C<T>::backward_doit (MF& outmf, Scaling /*scaling*/)
581+
void R2C<T>::backward (MF& outmf, Scaling scaling)
584582
{
583+
AMREX_ALWAYS_ASSERT(scaling == Scaling::none); // xxxxx TODO
584+
585585
exec_c2c<Direction::backward>(m_fft_bwd_z, m_cz);
586586
if ( m_cmd_z2y) {
587587
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);
@@ -716,6 +716,51 @@ void R2C<T>::post_forward_doit (F const& post_forward)
716716
}
717717
}
718718

719+
template <typename T>
720+
std::pair<typename R2C<T>::cMF *, IntVect>
721+
R2C<T>::getSpectralData ()
722+
{
723+
if (!m_cz.empty()) {
724+
return std::make_pair(&m_cz, IntVect{AMREX_D_DECL(2,0,1)});
725+
} else if (!m_cy.empty()) {
726+
return std::make_pair(&m_cy, IntVect{AMREX_D_DECL(1,0,2)});
727+
} else {
728+
return std::make_pair(&m_cx, IntVect{AMREX_D_DECL(0,1,2)});
729+
}
730+
}
731+
732+
template <typename T>
733+
void R2C<T>::forward (MF const& inmf, cMF& outmf, Scaling scaling)
734+
{
735+
forward(inmf);
736+
if (!m_cz.empty()) { // m_cz's ordering is z,x,y
737+
amrex::Abort("xxxxx todo, forward m_cz");
738+
} else if (!m_cy.empty()) { // m_cy's order (y,x,z) -> (x,y,z)
739+
MultiBlockCommMetaData cmd
740+
(outmf.boxArray(), outmf.DistributionMap(), m_spectral_domain_x,
741+
m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x);
742+
ParallelCopy(outmf, m_cy, cmd, 0, 0, 1, m_dtos_y2x);
743+
} else {
744+
outmf.ParallelCopy(m_cx, 0, 0, 1);
745+
}
746+
}
747+
748+
template <typename T>
749+
void R2C<T>::backward (cMF const& inmf, MF& outmf, Scaling scaling)
750+
{
751+
if (!m_cz.empty()) { // m_cz's ordering is z,x,y
752+
amrex::Abort("xxxxx todo, backward m_cz");
753+
} else if (!m_cy.empty()) { // (x,y,z) -> m_cy's ordering (y,x,z)
754+
MultiBlockCommMetaData cmd
755+
(m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
756+
inmf.boxArray(), inmf.DistributionMap(), IntVect(0), m_dtos_x2y);
757+
ParallelCopy(m_cy, inmf, cmd, 0, 0, 1, m_dtos_x2y);
758+
} else {
759+
m_cx.ParallelCopy(inmf, 0, 0, 1);
760+
}
761+
backward(outmf);
762+
}
763+
719764
}
720765

721766
#endif

Src/FFT/AMReX_FFT.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
#include <AMReX_FFT.H>
2+
#include <algorithm>
23

3-
namespace amrex::FFT
4+
namespace amrex::FFT::detail
45
{
56

6-
#ifdef AMREX_USE_HIP
7-
namespace detail
7+
DistributionMapping make_iota_distromap (Long n)
88
{
9+
AMREX_ASSERT(n <= ParallelDescriptor::NProcs());
10+
Vector<int> pm(n);
11+
std::iota(pm.begin(), pm.end(), 0);
12+
return DistributionMapping(std::move(pm));
13+
}
14+
15+
#ifdef AMREX_USE_HIP
916
void hip_execute (rocfft_plan plan, void **in, void **out)
1017
{
1118
rocfft_execution_info execinfo = nullptr;
@@ -26,7 +33,6 @@ void hip_execute (rocfft_plan plan, void **in, void **out)
2633

2734
AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_destroy(execinfo));
2835
}
29-
}
3036
#endif
3137

3238
}

Src/FFT/AMReX_FFT_Helper.H

Lines changed: 4 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22
#define AMREX_FFT_HELPER_H_
33
#include <AMReX_Config.H>
44

5-
#include <AMReX.H>
6-
#include <AMReX_Geometry.H>
7-
#include <AMReX_Gpu.H>
8-
#include <AMReX_GpuComplex.H>
9-
#include <AMReX_Math.H>
5+
#include <AMReX_DistributionMapping.H>
106

117
namespace amrex::FFT
128
{
@@ -24,49 +20,10 @@ struct Info
2420
Info& setBatchMode (bool x) { batch_mode = x; return *this; }
2521
};
2622

27-
template <typename T>
28-
struct PoissonSpectral
23+
namespace detail
2924
{
30-
PoissonSpectral (Geometry const& geom)
31-
: fac({AMREX_D_DECL(T(2)*Math::pi<T>()/T(geom.ProbLength(0)),
32-
T(2)*Math::pi<T>()/T(geom.ProbLength(1)),
33-
T(2)*Math::pi<T>()/T(geom.ProbLength(2)))}),
34-
dx({AMREX_D_DECL(T(geom.CellSize(0)),
35-
T(geom.CellSize(1)),
36-
T(geom.CellSize(2)))}),
37-
scale(T(1.0/geom.Domain().d_numPts())),
38-
len(geom.Domain().length())
39-
{
40-
static_assert(std::is_floating_point_v<T>);
41-
}
42-
43-
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
44-
void operator() (int i, int j, int k, GpuComplex<T>& spectral_data) const
45-
{
46-
amrex::ignore_unused(i,j,k);
47-
// the values in the upper-half of the spectral array in y and z
48-
// are here interpreted as negative wavenumbers
49-
AMREX_D_TERM(T a = fac[0]*i;,
50-
T b = (j < len[1]/2) ? fac[1]*j : fac[1]*(len[1]-j);,
51-
T c = (k < len[2]/2) ? fac[2]*k : fac[2]*(len[2]-k));
52-
T k2 = AMREX_D_TERM(T(2)*(std::cos(a*dx[0])-T(1))/(dx[0]*dx[0]),
53-
+T(2)*(std::cos(b*dx[1])-T(1))/(dx[1]*dx[1]),
54-
+T(2)*(std::cos(c*dx[2])-T(1))/(dx[2]*dx[2]));
55-
if (k2 != T(0)) {
56-
spectral_data /= k2;
57-
} else {
58-
// interpretation here is that the average value of the
59-
// solution is zero
60-
spectral_data = 0;
61-
}
62-
spectral_data *= scale;
63-
}
64-
65-
GpuArray<T,AMREX_SPACEDIM> fac;
66-
GpuArray<T,AMREX_SPACEDIM> dx;
67-
T scale;
68-
IntVect len;
69-
};
25+
DistributionMapping make_iota_distromap (Long n);
26+
}
7027

7128
}
7229

0 commit comments

Comments
 (0)