Skip to content

Commit c59faf3

Browse files
committed
FFT::PoissonHybrid: Add interfaces for user provided dz
1 parent 2aa22d1 commit c59faf3

File tree

2 files changed

+210
-127
lines changed

2 files changed

+210
-127
lines changed

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ template <typename MF = MultiFab>
9191
class PoissonHybrid
9292
{
9393
public:
94+
using T = typename MF::value_type;
9495

9596
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
9697
explicit PoissonHybrid (Geometry const& geom)
@@ -104,6 +105,11 @@ public:
104105
}
105106

106107
void solve (MF& soln, MF const& rhs);
108+
void solve (MF& soln, MF const& rhs, Vector<T> const& dz);
109+
void solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz);
110+
111+
template <typename DZ>
112+
void solve_doit (MF& soln, MF const& rhs, DZ const& dz); // has to be public for cuda
107113

108114
private:
109115
Geometry m_geom;
@@ -225,14 +231,44 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
225231

226232
template <typename MF>
227233
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
234+
{
235+
auto delz = T(m_geom.CellSize(AMREX_SPACEDIM-1));
236+
struct DZ {
237+
[[nodiscard]] constexpr T operator[] (int) const { return m_delz; }
238+
T m_delz;
239+
};
240+
solve_doit(soln, rhs, DZ{delz});
241+
}
242+
243+
template <typename MF>
244+
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz)
245+
{
246+
auto const* pdz = dz.dataPtr();
247+
solve_doit(soln, rhs, pdz);
248+
}
249+
250+
template <typename MF>
251+
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
252+
{
253+
#ifdef AMREX_USE_GPU
254+
Gpu::DeviceVector<T> d_dz(dz.size());
255+
Gpu::htod_memcpy_async(d_dz.data(), dz.data(), dz.size()*sizeof(T));
256+
auto const* pdz = d_dz.data();
257+
#else
258+
auto const* pdz = dz.data();
259+
#endif
260+
solve_doit(soln, rhs, pdz);
261+
}
262+
263+
template <typename MF>
264+
template <typename DZ>
265+
void PoissonHybrid<MF>::solve_doit (MF& soln, MF const& rhs, DZ const& dz)
228266
{
229267
BL_PROFILE("FFT::PoissonHybrid::solve");
230268

231269
#if (AMREX_SPACEDIM < 3)
232-
amrex::ignore_unused(soln, rhs);
270+
amrex::ignore_unused(soln, rhs, dz);
233271
#else
234-
using T = typename MF::value_type;
235-
236272
auto facx = T(2)*Math::pi<T>()/T(m_geom.ProbLength(0));
237273
auto facy = T(2)*Math::pi<T>()/T(m_geom.ProbLength(1));
238274
auto dx = T(m_geom.CellSize(0));
@@ -242,9 +278,6 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
242278
auto ny = m_geom.Domain().length(1);
243279
auto nz = m_geom.Domain().length(2);
244280

245-
Gpu::DeviceVector<T> delzv(nz, T(m_geom.CellSize(2)));
246-
auto const* delz = delzv.data();
247-
248281
Box cdomain = m_geom.Domain();
249282
cdomain.setBig(0,cdomain.length(0)/2);
250283
auto cba = amrex::decompose(cdomain, ParallelContext::NProcsSub(),
@@ -283,18 +316,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
283316
for(int k=0; k < nz; k++) {
284317
if(k==0) {
285318
ald(i,j,k) = 0.;
286-
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
319+
cud(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
287320
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
288321
} else if (k == nz-1) {
289-
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
322+
ald(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
290323
cud(i,j,k) = 0.;
291324
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
292325
if (i == 0 && j == 0) {
293326
bd(i,j,k) *= 2.0;
294327
}
295328
} else {
296-
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
297-
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
329+
ald(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
330+
cud(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
298331
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
299332
}
300333
}
@@ -339,18 +372,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
339372
for(int k=0; k < nz; k++) {
340373
if(k==0) {
341374
ald[k] = 0.;
342-
cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
375+
cud[k] = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
343376
bd[k] = k2 -ald[k]-cud[k];
344377
} else if (k == nz-1) {
345-
ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
378+
ald[k] = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
346379
cud[k] = 0.;
347380
bd[k] = k2 -ald[k]-cud[k];
348381
if (i == 0 && j == 0) {
349382
bd[k] *= 2.0;
350383
}
351384
} else {
352-
ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
353-
cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
385+
ald[k] = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
386+
cud[k] = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
354387
bd[k] = k2 -ald[k]-cud[k];
355388
}
356389
}

0 commit comments

Comments
 (0)