Skip to content

Commit 80cf43c

Browse files
committed
FFT::PoissonHybrid: Add interfaces for user provided dz
1 parent 2aa22d1 commit 80cf43c

File tree

2 files changed

+214
-127
lines changed

2 files changed

+214
-127
lines changed

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ template <typename MF = MultiFab>
9191
class PoissonHybrid
9292
{
9393
public:
94+
using T = typename MF::value_type;
9495

9596
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
9697
explicit PoissonHybrid (Geometry const& geom)
@@ -104,6 +105,11 @@ public:
104105
}
105106

106107
void solve (MF& soln, MF const& rhs);
108+
void solve (MF& soln, MF const& rhs, Vector<T> const& dz);
109+
void solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz);
110+
111+
template <typename DZ>
112+
void solve_doit (MF& soln, MF const& rhs, DZ const& dz); // has to be public for cuda
107113

108114
private:
109115
Geometry m_geom;
@@ -223,16 +229,50 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
223229

224230
#endif /* AMREX_SPACEDIM == 3 */
225231

232+
namespace fft_poisson_detail {
233+
template <typename T>
234+
struct DZ {
235+
[[nodiscard]] constexpr T operator[] (int) const { return m_delz; }
236+
T m_delz;
237+
};
238+
}
239+
226240
template <typename MF>
227241
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
242+
{
243+
auto delz = T(m_geom.CellSize(AMREX_SPACEDIM-1));
244+
solve_doit(soln, rhs, fft_poisson_detail::DZ<T>{delz});
245+
}
246+
247+
template <typename MF>
248+
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz)
249+
{
250+
auto const* pdz = dz.dataPtr();
251+
solve_doit(soln, rhs, pdz);
252+
}
253+
254+
template <typename MF>
255+
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
256+
{
257+
#ifdef AMREX_USE_GPU
258+
Gpu::DeviceVector<T> d_dz(dz.size());
259+
Gpu::htod_memcpy_async(d_dz.data(), dz.data(), dz.size()*sizeof(T));
260+
auto const* pdz = d_dz.data();
261+
#else
262+
auto const* pdz = dz.data();
263+
#endif
264+
solve_doit(soln, rhs, pdz);
265+
}
266+
267+
template <typename MF>
268+
template <typename DZ>
269+
void PoissonHybrid<MF>::solve_doit (MF& soln, MF const& rhs, DZ const& dz)
228270
{
229271
BL_PROFILE("FFT::PoissonHybrid::solve");
230272

231273
#if (AMREX_SPACEDIM < 3)
232-
amrex::ignore_unused(soln, rhs);
274+
amrex::ignore_unused(soln, rhs, dz);
233275
#else
234-
using T = typename MF::value_type;
235-
236276
auto facx = T(2)*Math::pi<T>()/T(m_geom.ProbLength(0));
237277
auto facy = T(2)*Math::pi<T>()/T(m_geom.ProbLength(1));
238278
auto dx = T(m_geom.CellSize(0));
@@ -242,9 +282,6 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
242282
auto ny = m_geom.Domain().length(1);
243283
auto nz = m_geom.Domain().length(2);
244284

245-
Gpu::DeviceVector<T> delzv(nz, T(m_geom.CellSize(2)));
246-
auto const* delz = delzv.data();
247-
248285
Box cdomain = m_geom.Domain();
249286
cdomain.setBig(0,cdomain.length(0)/2);
250287
auto cba = amrex::decompose(cdomain, ParallelContext::NProcsSub(),
@@ -283,18 +320,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
283320
for(int k=0; k < nz; k++) {
284321
if(k==0) {
285322
ald(i,j,k) = 0.;
286-
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
323+
cud(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
287324
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
288325
} else if (k == nz-1) {
289-
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
326+
ald(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
290327
cud(i,j,k) = 0.;
291328
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
292329
if (i == 0 && j == 0) {
293330
bd(i,j,k) *= 2.0;
294331
}
295332
} else {
296-
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
297-
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
333+
ald(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
334+
cud(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
298335
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
299336
}
300337
}
@@ -339,18 +376,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
339376
for(int k=0; k < nz; k++) {
340377
if(k==0) {
341378
ald[k] = 0.;
342-
cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
379+
cud[k] = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
343380
bd[k] = k2 -ald[k]-cud[k];
344381
} else if (k == nz-1) {
345-
ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
382+
ald[k] = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
346383
cud[k] = 0.;
347384
bd[k] = k2 -ald[k]-cud[k];
348385
if (i == 0 && j == 0) {
349386
bd[k] *= 2.0;
350387
}
351388
} else {
352-
ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
353-
cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
389+
ald[k] = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
390+
cud[k] = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
354391
bd[k] = k2 -ald[k]-cud[k];
355392
}
356393
}

0 commit comments

Comments
 (0)