Skip to content

Commit 57380f3

Browse files
authored
FFT: Add new domain decomposition strategy (#4221)
Add the option of slab decomposition as an alternative to pencil decomposition. With slab decomposition, the x and y directions can be transformed together without MPI communication.
1 parent 8e7bb00 commit 57380f3

File tree

8 files changed

+297
-111
lines changed

8 files changed

+297
-111
lines changed

.github/workflows/apps.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ jobs:
114114
runs-on: ubuntu-latest
115115
needs: check_changes
116116
if: needs.check_changes.outputs.has_non_docs_changes == 'true'
117+
steps:
117118
- uses: actions/checkout@v4
118119
- name: Checkout pyamrex
119120
uses: actions/checkout@v4

.github/workflows/dependencies/dependencies_hip.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ echo 'export PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/opt/rocm/profiler/bin:/opt/r
4040

4141
# we should not need to export HIP_PATH=/opt/rocm/hip with those installs
4242

43+
sudo apt-get clean
4344
sudo apt-get update
4445

4546
# Ref.: https://rocmdocs.amd.com/en/latest/Installation_Guide/Installation-Guide.html#installing-development-packages-for-cross-compilation

Src/FFT/AMReX_FFT_Helper.H

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#include <algorithm>
3636
#include <complex>
37+
#include <limits>
3738
#include <memory>
3839
#include <utility>
3940
#include <variant>
@@ -43,6 +44,8 @@ namespace amrex::FFT
4344

4445
enum struct Direction { forward, backward, both, none };
4546

47+
enum struct DomainStrategy { slab, pencil };
48+
4649
AMREX_ENUM( Boundary, periodic, even, odd );
4750

4851
enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
@@ -55,7 +58,11 @@ struct Info
5558
//! batch size.
5659
bool batch_mode = false;
5760

61+
//! Max number of processes to use
62+
int nprocs = std::numeric_limits<int>::max();
63+
5864
Info& setBatchMode (bool x) { batch_mode = x; return *this; }
65+
Info& setNumProcs (int n) { nprocs = n; return *this; }
5966
};
6067

6168
#ifdef AMREX_USE_HIP
@@ -172,18 +179,34 @@ struct Plan
172179
}
173180

174181
template <Direction D>
175-
void init_r2c (Box const& box, T* pr, VendorComplex* pc)
182+
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
176183
{
177184
static_assert(D == Direction::forward || D == Direction::backward);
178185

186+
int rank = is_2d_transform ? 2 : 1;
187+
179188
kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
180189
defined = true;
181190
pf = (void*)pr;
182191
pb = (void*)pc;
183192

184-
n = box.length(0);
185-
int nc = (n/2) + 1;
186-
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
193+
int len[2] = {};
194+
if (rank == 1) {
195+
len[0] = box.length(0);
196+
len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
197+
} else {
198+
len[0] = box.length(1); // Most FFT libraries assume row-major ordering
199+
len[1] = box.length(0); // except for rocfft
200+
}
201+
int nr = (rank == 1) ? len[0] : len[0]*len[1];
202+
n = nr;
203+
int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
204+
#if (AMREX_SPACEDIM == 1)
205+
howmany = 1;
206+
#else
207+
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
208+
: AMREX_D_TERM(1, *1 , *box.length(2));
209+
#endif
187210

188211
amrex::ignore_unused(nc);
189212

@@ -193,43 +216,52 @@ struct Plan
193216
if constexpr (D == Direction::forward) {
194217
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
195218
AMREX_CUFFT_SAFE_CALL
196-
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, nc, fwd_type, howmany, &work_size));
219+
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
197220
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
198221
} else {
199222
cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
200223
AMREX_CUFFT_SAFE_CALL
201-
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, nc, nullptr, 1, n, bwd_type, howmany, &work_size));
224+
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
202225
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
203226
}
204227
#elif defined(AMREX_USE_HIP)
205228

206229
auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
207-
const std::size_t length = n;
230+
// switch to column-major ordering
231+
std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
208232
if constexpr (D == Direction::forward) {
209233
AMREX_ROCFFT_SAFE_CALL
210234
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
211-
rocfft_transform_type_real_forward, prec, 1,
212-
&length, howmany, nullptr));
235+
rocfft_transform_type_real_forward, prec, rank,
236+
length, howmany, nullptr));
213237
} else {
214238
AMREX_ROCFFT_SAFE_CALL
215239
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
216-
rocfft_transform_type_real_inverse, prec, 1,
217-
&length, howmany, nullptr));
240+
rocfft_transform_type_real_inverse, prec, rank,
241+
length, howmany, nullptr));
218242
}
219243

220244
#elif defined(AMREX_USE_SYCL)
221245

222-
auto* pp = new mkl_desc_r(n);
246+
mkl_desc_r* pp;
247+
if (rank == 1) {
248+
pp = new mkl_desc_r(len[0]);
249+
} else {
250+
pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
251+
}
223252
#ifndef AMREX_USE_MKL_DFTI_2024
224253
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
225254
oneapi::mkl::dft::config_value::NOT_INPLACE);
226255
#else
227256
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
228257
#endif
229258
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
230-
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
259+
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
231260
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
232-
std::vector<std::int64_t> strides = {0,1};
261+
std::vector<std::int64_t> strides;
262+
strides.push_back(0);
263+
if (rank == 2) { strides.push_back(len[1]); }
264+
strides.push_back(1);
233265
#ifndef AMREX_USE_MKL_DFTI_2024
234266
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
235267
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
@@ -247,21 +279,21 @@ struct Plan
247279
if constexpr (std::is_same_v<float,T>) {
248280
if constexpr (D == Direction::forward) {
249281
plan = fftwf_plan_many_dft_r2c
250-
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
282+
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
251283
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
252284
} else {
253285
plan = fftwf_plan_many_dft_c2r
254-
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
286+
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
255287
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
256288
}
257289
} else {
258290
if constexpr (D == Direction::forward) {
259291
plan = fftw_plan_many_dft_r2c
260-
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
292+
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
261293
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
262294
} else {
263295
plan = fftw_plan_many_dft_c2r
264-
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
296+
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
265297
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
266298
}
267299
}
@@ -1087,13 +1119,17 @@ namespace detail
10871119
template <typename FA1, typename FA2>
10881120
std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
10891121
{
1122+
bool not_same_fa = true;
1123+
if constexpr (std::is_same_v<FA1,FA2>) {
1124+
not_same_fa = (&fa1 != &fa2);
1125+
}
10901126
using FAB1 = typename FA1::FABType::value_type;
10911127
using FAB2 = typename FA2::FABType::value_type;
10921128
using T1 = typename FAB1::value_type;
10931129
using T2 = typename FAB2::value_type;
10941130
auto myproc = ParallelContext::MyProcSub();
10951131
bool alloc_1 = (myproc < fa1.size());
1096-
bool alloc_2 = (myproc < fa2.size());
1132+
bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
10971133
void* p = nullptr;
10981134
if (alloc_1 && alloc_2) {
10991135
Box const& box1 = fa1.fabbox(myproc);

Src/FFT/AMReX_FFT_OpenBCSolver.H

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public:
1515
using MF = typename R2C<T>::MF;
1616
using cMF = typename R2C<T>::cMF;
1717

18-
explicit OpenBCSolver (Box const& domain);
18+
explicit OpenBCSolver (Box const& domain, Info const& info = Info{});
1919

2020
template <class F>
2121
void setGreensFunction (F const& greens_function);
@@ -25,34 +25,70 @@ public:
2525
[[nodiscard]] Box const& Domain () const { return m_domain; }
2626

2727
private:
28+
static Box make_grown_domain (Box const& domain, Info const& info);
29+
2830
Box m_domain;
31+
Info m_info;
2932
R2C<T> m_r2c;
3033
cMF m_G_fft;
34+
std::unique_ptr<R2C<T>> m_r2c_green;
3135
};
3236

3337
template <typename T>
34-
OpenBCSolver<T>::OpenBCSolver (Box const& domain)
38+
// Build the enlarged box used for the open-boundary FFT: same lower corner and index type as `domain`, upper corner pushed out by `domain`'s own length in each dimension (in 3D batch mode the z extent is left unchanged).
Box OpenBCSolver<T>::make_grown_domain (Box const& domain, Info const& info)
39+
{
40+
IntVect len = domain.length(); // per-dimension amount added to the upper corner
41+
#if (AMREX_SPACEDIM == 3)
42+
if (info.batch_mode) { len[2] = 0; } // batch mode: do not extend the z direction
43+
#else
44+
amrex::ignore_unused(info); // `info` is only consulted in 3D builds
45+
#endif
46+
return Box(domain.smallEnd(), domain.bigEnd()+len, domain.ixType());
47+
}
48+
49+
template <typename T>
50+
OpenBCSolver<T>::OpenBCSolver (Box const& domain, Info const& info)
3551
: m_domain(domain),
36-
m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType()))
52+
m_info(info),
53+
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info), info)
3754
{
38-
auto [sd, ord] = m_r2c.getSpectralData();
39-
amrex::ignore_unused(ord);
40-
m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
55+
#if (AMREX_SPACEDIM == 3)
56+
if (m_info.batch_mode) {
57+
auto gdom = make_grown_domain(domain,m_info);
58+
gdom.enclosedCells(2);
59+
gdom.setSmall(2, 0);
60+
int nprocs = std::min({ParallelContext::NProcsSub(),
61+
m_info.nprocs,
62+
m_domain.length(2)});
63+
gdom.setBig(2, nprocs-1);
64+
m_r2c_green = std::make_unique<R2C<T>>(gdom,info);
65+
auto [sd, ord] = m_r2c_green->getSpectralData();
66+
m_G_fft = cMF(*sd, amrex::make_alias, 0, 1);
67+
} else
68+
#endif
69+
{
70+
amrex::ignore_unused(m_r2c_green);
71+
auto [sd, ord] = m_r2c.getSpectralData();
72+
amrex::ignore_unused(ord);
73+
m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
74+
}
4175
}
4276

4377
template <typename T>
4478
template <class F>
4579
void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
4680
{
47-
auto* infab = detail::get_fab(m_r2c.m_rx);
81+
auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx)
82+
: detail::get_fab(m_r2c.m_rx);
4883
auto const& lo = m_domain.smallEnd();
4984
auto const& lo3 = lo.dim3();
5085
auto const& len = m_domain.length3d();
5186
if (infab) {
5287
auto const& a = infab->array();
5388
auto box = infab->box();
5489
GpuArray<int,3> nimages{1,1,1};
55-
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
90+
int ndims = m_info.batch_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM; // batch mode does not double the last dimension, so only fold the others; otherwise fold all of them
91+
for (int idim = 0; idim < ndims; ++idim) {
5692
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
5793
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
5894
nimages[idim] = 2;
@@ -62,46 +98,59 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
6298
box.shift(-lo);
6399
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
64100
{
101+
T G;
65102
if (i == len[0] || j == len[1] || k == len[2]) {
66-
a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0);
103+
G = 0;
67104
} else {
68105
auto ii = i;
69106
auto jj = (j > len[1]) ? 2*len[1]-j : j;
70107
auto kk = (k > len[2]) ? 2*len[2]-k : k;
71-
auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
72-
for (int koff = 0; koff < nimages[2]; ++koff) {
73-
int k2 = (koff == 0) ? k : 2*len[2]-k;
74-
if (k2 == 2*len[2]) { continue; }
75-
for (int joff = 0; joff < nimages[1]; ++joff) {
76-
int j2 = (joff == 0) ? j : 2*len[1]-j;
77-
if (j2 == 2*len[1]) { continue; }
78-
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
79-
int i2 = (ioff == 0) ? i : 2*len[0]-i;
80-
if (i2 == 2*len[0]) { continue; }
81-
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
108+
G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
109+
}
110+
for (int koff = 0; koff < nimages[2]; ++koff) {
111+
int k2 = (koff == 0) ? k : 2*len[2]-k;
112+
if ((k2 == 2*len[2]) || (koff == 1 && k == len[2])) {
113+
continue;
114+
}
115+
for (int joff = 0; joff < nimages[1]; ++joff) {
116+
int j2 = (joff == 0) ? j : 2*len[1]-j;
117+
if ((j2 == 2*len[1]) || (joff == 1 && j == len[1])) {
118+
continue;
119+
}
120+
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
121+
int i2 = (ioff == 0) ? i : 2*len[0]-i;
122+
if ((i2 == 2*len[0]) || (ioff == 1 && i == len[0])) {
123+
continue;
82124
}
125+
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
83126
}
84127
}
85128
}
86129
});
87130
}
88131

89-
m_r2c.forward(m_r2c.m_rx);
132+
if (m_info.batch_mode) {
133+
m_r2c_green->forward(m_r2c_green->m_rx);
134+
} else {
135+
m_r2c.forward(m_r2c.m_rx);
136+
}
90137

91-
auto [sd, ord] = m_r2c.getSpectralData();
92-
amrex::ignore_unused(ord);
93-
auto const* srcfab = detail::get_fab(*sd);
94-
if (srcfab) {
95-
auto* dstfab = detail::get_fab(m_G_fft);
96-
if (dstfab) {
138+
if (!m_info.batch_mode) {
139+
auto [sd, ord] = m_r2c.getSpectralData();
140+
amrex::ignore_unused(ord);
141+
auto const* srcfab = detail::get_fab(*sd);
142+
if (srcfab) {
143+
auto* dstfab = detail::get_fab(m_G_fft);
144+
if (dstfab) {
97145
#if defined(AMREX_USE_GPU)
98-
Gpu::dtod_memcpy_async
146+
Gpu::dtod_memcpy_async
99147
#else
100-
std::memcpy
148+
std::memcpy
101149
#endif
102-
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
103-
} else {
104-
amrex::Abort("FFT::OpenBCSolver: how did this happen");
150+
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
151+
} else {
152+
amrex::Abort("FFT::OpenBCSolver: how did this happen");
153+
}
105154
}
106155
}
107156
}
@@ -115,7 +164,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
115164

116165
m_r2c.forward(inmf);
117166

118-
auto scaling_factor = T(1) / T(m_r2c.m_real_domain.numPts());
167+
auto scaling_factor = m_r2c.scalingFactor();
119168

120169
auto const* gfab = detail::get_fab(m_G_fft);
121170
if (gfab) {
@@ -125,9 +174,24 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
125174
if (rhofab) {
126175
auto* pdst = rhofab->dataPtr();
127176
auto const* psrc = gfab->dataPtr();
128-
amrex::ParallelFor(rhofab->box().numPts(), [=] AMREX_GPU_DEVICE (Long i)
177+
Box const& rhobox = rhofab->box();
178+
#if (AMREX_SPACEDIM == 3)
179+
Long leng = gfab->box().numPts();
180+
if (m_info.batch_mode) {
181+
AMREX_ASSERT(gfab->box().length(2) == 1 &&
182+
leng == (rhobox.length(0) * rhobox.length(1)));
183+
} else {
184+
AMREX_ASSERT(leng == rhobox.numPts());
185+
}
186+
#endif
187+
amrex::ParallelFor(rhobox.numPts(), [=] AMREX_GPU_DEVICE (Long i)
129188
{
130-
pdst[i] *= psrc[i] * scaling_factor;
189+
#if (AMREX_SPACEDIM == 3)
190+
Long isrc = i % leng;
191+
#else
192+
Long isrc = i;
193+
#endif
194+
pdst[i] *= psrc[isrc] * scaling_factor;
131195
});
132196
} else {
133197
amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?");

0 commit comments

Comments
 (0)