Skip to content

Commit a51e406

Browse files
committed
FFT: Add new domain decomposition strategy
In addition to pencil decomposition, the FFT can now optionally use slab decomposition, which allows the x and y directions to be transformed together without MPI communication.
1 parent 8e7bb00 commit a51e406

File tree

4 files changed

+153
-71
lines changed

4 files changed

+153
-71
lines changed

Src/FFT/AMReX_FFT_Helper.H

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ namespace amrex::FFT
4343

4444
enum struct Direction { forward, backward, both, none };
4545

46+
enum struct DomainStrategy { slab, pencil };
47+
4648
AMREX_ENUM( Boundary, periodic, even, odd );
4749

4850
enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
@@ -172,18 +174,29 @@ struct Plan
172174
}
173175

174176
template <Direction D>
175-
void init_r2c (Box const& box, T* pr, VendorComplex* pc)
177+
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
176178
{
177179
static_assert(D == Direction::forward || D == Direction::backward);
178180

181+
int rank = is_2d_transform ? 2 : 1;
182+
179183
kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
180184
defined = true;
181185
pf = (void*)pr;
182186
pb = (void*)pc;
183187

184-
n = box.length(0);
185-
int nc = (n/2) + 1;
186-
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
188+
int len[2] = {};
189+
if (rank == 1) {
190+
len[0] = box.length(0);
191+
} else {
192+
len[0] = box.length(1);
193+
len[1] = box.length(0);
194+
}
195+
int nr = (rank == 1) ? len[0] : len[0]*len[1];
196+
n = nr;
197+
int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
198+
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
199+
: AMREX_D_TERM(1, *1 , *box.length(2));
187200

188201
amrex::ignore_unused(nc);
189202

@@ -193,43 +206,51 @@ struct Plan
193206
if constexpr (D == Direction::forward) {
194207
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
195208
AMREX_CUFFT_SAFE_CALL
196-
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, nc, fwd_type, howmany, &work_size));
209+
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
197210
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
198211
} else {
199212
cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
200213
AMREX_CUFFT_SAFE_CALL
201-
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, nc, nullptr, 1, n, bwd_type, howmany, &work_size));
214+
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
202215
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
203216
}
204217
#elif defined(AMREX_USE_HIP)
205218

206219
auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
207-
const std::size_t length = n;
220+
const std::size_t length[2] = {std::size_t(len[0]), std::size_t(len[1])};
208221
if constexpr (D == Direction::forward) {
209222
AMREX_ROCFFT_SAFE_CALL
210223
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
211-
rocfft_transform_type_real_forward, prec, 1,
212-
&length, howmany, nullptr));
224+
rocfft_transform_type_real_forward, prec, rank,
225+
length, howmany, nullptr));
213226
} else {
214227
AMREX_ROCFFT_SAFE_CALL
215228
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
216-
rocfft_transform_type_real_inverse, prec, 1,
217-
&length, howmany, nullptr));
229+
rocfft_transform_type_real_inverse, prec, rank,
230+
length, howmany, nullptr));
218231
}
219232

220233
#elif defined(AMREX_USE_SYCL)
221234

222-
auto* pp = new mkl_desc_r(n);
235+
mkl_desc_c* pp;
236+
if (rank == 1) {
237+
pp = new mkl_desc_r(len[0]);
238+
} else {
239+
pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
240+
}
223241
#ifndef AMREX_USE_MKL_DFTI_2024
224242
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
225243
oneapi::mkl::dft::config_value::NOT_INPLACE);
226244
#else
227245
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
228246
#endif
229247
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
230-
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
248+
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
231249
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
232-
std::vector<std::int64_t> strides = {0,1};
250+
std::vector<std::int64_t> strides;
251+
strides.push_back(0);
252+
if (rank == 2) { strides.push_back(len[1]); }
253+
rank.push_back(1);
233254
#ifndef AMREX_USE_MKL_DFTI_2024
234255
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
235256
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
@@ -247,21 +268,21 @@ struct Plan
247268
if constexpr (std::is_same_v<float,T>) {
248269
if constexpr (D == Direction::forward) {
249270
plan = fftwf_plan_many_dft_r2c
250-
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
271+
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
251272
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
252273
} else {
253274
plan = fftwf_plan_many_dft_c2r
254-
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
275+
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
255276
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
256277
}
257278
} else {
258279
if constexpr (D == Direction::forward) {
259280
plan = fftw_plan_many_dft_r2c
260-
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
281+
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
261282
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
262283
} else {
263284
plan = fftw_plan_many_dft_c2r
264-
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
285+
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
265286
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
266287
}
267288
}
@@ -1087,13 +1108,17 @@ namespace detail
10871108
template <typename FA1, typename FA2>
10881109
std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
10891110
{
1111+
bool not_same_fa = true;
1112+
if constexpr (std::is_same_v<FA1,FA2>) {
1113+
not_same_fa = (&fa1 != &fa2);
1114+
}
10901115
using FAB1 = typename FA1::FABType::value_type;
10911116
using FAB2 = typename FA2::FABType::value_type;
10921117
using T1 = typename FAB1::value_type;
10931118
using T2 = typename FAB2::value_type;
10941119
auto myproc = ParallelContext::MyProcSub();
10951120
bool alloc_1 = (myproc < fa1.size());
1096-
bool alloc_2 = (myproc < fa2.size());
1121+
bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
10971122
void* p = nullptr;
10981123
if (alloc_1 && alloc_2) {
10991124
Box const& box1 = fa1.fabbox(myproc);

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,42 @@ public:
1919
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
2020
Poisson (Geometry const& geom,
2121
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> const& bc)
22-
: m_geom(geom), m_bc(bc), m_r2x(geom.Domain(),bc)
23-
{}
22+
: m_geom(geom), m_bc(bc)
23+
{
24+
bool all_periodic = true;
25+
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
26+
all_periodic = all_periodic
27+
&& (bc[idim].first == Boundary::periodic)
28+
&& (bc[idim].second == Boundary::periodic);
29+
}
30+
// Pick the solver from the caller-supplied bc (all_periodic computed above),
// not from the Geometry's own periodicity flags.
if (all_periodic) {
31+
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain());
32+
} else {
33+
m_r2x = std::make_unique<R2X<typename MF::value_type>> (m_geom.Domain(), m_bc);
34+
}
35+
}
2436

2537
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
2638
explicit Poisson (Geometry const& geom)
2739
: m_geom(geom),
2840
m_bc{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic),
2941
std::make_pair(Boundary::periodic,Boundary::periodic),
30-
std::make_pair(Boundary::periodic,Boundary::periodic))},
31-
m_r2x(geom.Domain(),m_bc)
42+
std::make_pair(Boundary::periodic,Boundary::periodic))}
3243
{
33-
AMREX_ALWAYS_ASSERT(m_geom.isAllPeriodic());
44+
if (m_geom.isAllPeriodic()) {
45+
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain());
46+
} else {
47+
amrex::Abort("FFT::Poisson: wrong BC");
48+
}
3449
}
3550

3651
void solve (MF& soln, MF const& rhs);
3752

3853
private:
3954
Geometry m_geom;
4055
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> m_bc;
41-
R2X<typename MF::value_type> m_r2x;
56+
std::unique_ptr<R2X<typename MF::value_type>> m_r2x;
57+
std::unique_ptr<R2C<typename MF::value_type>> m_r2c;
4258
};
4359

4460
#if (AMREX_SPACEDIM == 3)
@@ -114,7 +130,7 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
114130
{AMREX_D_DECL(T(2)/T(m_geom.CellSize(0)*m_geom.CellSize(0)),
115131
T(2)/T(m_geom.CellSize(1)*m_geom.CellSize(1)),
116132
T(2)/T(m_geom.CellSize(2)*m_geom.CellSize(2)))};
117-
auto scale = m_r2x.scalingFactor();
133+
auto scale = (m_r2x) ? m_r2x->scalingFactor() : T(1)/T(m_geom.numPts());
118134

119135
GpuArray<T,AMREX_SPACEDIM> offset{AMREX_D_DECL(T(0),T(0),T(0))};
120136
// Not sure about odd-even and even-odd yet
@@ -133,8 +149,7 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
133149
}
134150
}
135151

136-
m_r2x.forwardThenBackward(rhs, soln,
137-
[=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data)
152+
auto f [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data)
138153
{
139154
amrex::ignore_unused(j,k);
140155
AMREX_D_TERM(T a = fac[0]*(i+offset[0]);,
@@ -147,7 +162,13 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
147162
spectral_data /= k2;
148163
}
149164
spectral_data *= scale;
150-
});
165+
};
166+
167+
if (m_r2x) {
168+
m_r2x->forwardThenBackward(rhs, soln, f);
169+
} else {
170+
m_r2c->forwardThenBackward(rhs, soln, f);
171+
}
151172
}
152173

153174
#if (AMREX_SPACEDIM == 3)

0 commit comments

Comments
 (0)