@@ -91,6 +91,7 @@ template <typename MF = MultiFab>
9191class PoissonHybrid
9292{
9393public:
94+ using T = typename MF::value_type;
9495
9596 template <typename FA=MF, std::enable_if_t <IsFabArray_v<FA>,int > = 0 >
9697 explicit PoissonHybrid (Geometry const & geom)
@@ -104,6 +105,11 @@ public:
104105 }
105106
106107 void solve (MF& soln, MF const & rhs);
108+ void solve (MF& soln, MF const & rhs, Vector<T> const & dz);
109+ void solve (MF& soln, MF const & rhs, Gpu::DeviceVector<T> const & dz);
110+
111+ template <typename DZ>
112+ void solve_doit (MF& soln, MF const & rhs, DZ const & dz); // has to be public for cuda
107113
108114private:
109115 Geometry m_geom;
@@ -223,16 +229,50 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
223229
224230#endif /* AMREX_SPACEDIM == 3 */
225231
namespace fft_poisson_detail {
    /**
     * Stand-in for an array of cell sizes when the spacing is uniform.
     *
     * Indexing with any int returns the single stored value, so the same
     * code can be written against either a pointer to per-cell spacings
     * or a constant spacing. Kept as an aggregate so that it can be
     * created with `DZ<T>{value}`.
     */
    template <typename T>
    struct DZ {
        /// Returns the stored spacing; the index is ignored.
        [[nodiscard]] constexpr T operator[] (int) const { return m_delz; }
        T m_delz; ///< The uniform spacing returned for every index.
    };
}
239+
226240template <typename MF>
227241void PoissonHybrid<MF>::solve (MF& soln, MF const & rhs)
242+ {
243+ auto delz = T (m_geom.CellSize (AMREX_SPACEDIM-1 ));
244+ solve_doit (soln, rhs, fft_poisson_detail::DZ<T>{delz});
245+ }
246+
247+ template <typename MF>
248+ void PoissonHybrid<MF>::solve (MF& soln, MF const & rhs, Gpu::DeviceVector<T> const & dz)
249+ {
250+ auto const * pdz = dz.dataPtr ();
251+ solve_doit (soln, rhs, pdz);
252+ }
253+
254+ template <typename MF>
255+ void PoissonHybrid<MF>::solve (MF& soln, MF const & rhs, Vector<T> const & dz)
256+ {
257+ #ifdef AMREX_USE_GPU
258+ Gpu::DeviceVector<T> d_dz (dz.size ());
259+ Gpu::htod_memcpy_async (d_dz.data (), dz.data (), dz.size ()*sizeof (T));
260+ auto const * pdz = d_dz.data ();
261+ #else
262+ auto const * pdz = dz.data ();
263+ #endif
264+ solve_doit (soln, rhs, pdz);
265+ }
266+
267+ template <typename MF>
268+ template <typename DZ>
269+ void PoissonHybrid<MF>::solve_doit (MF& soln, MF const & rhs, DZ const & dz)
228270{
229271 BL_PROFILE (" FFT::PoissonHybrid::solve" );
230272
231273#if (AMREX_SPACEDIM < 3)
232- amrex::ignore_unused (soln, rhs);
274+ amrex::ignore_unused (soln, rhs, dz );
233275#else
234- using T = typename MF::value_type;
235-
236276 auto facx = T (2 )*Math::pi<T>()/T (m_geom.ProbLength (0 ));
237277 auto facy = T (2 )*Math::pi<T>()/T (m_geom.ProbLength (1 ));
238278 auto dx = T (m_geom.CellSize (0 ));
@@ -242,9 +282,6 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
242282 auto ny = m_geom.Domain ().length (1 );
243283 auto nz = m_geom.Domain ().length (2 );
244284
245- Gpu::DeviceVector<T> delzv (nz, T (m_geom.CellSize (2 )));
246- auto const * delz = delzv.data ();
247-
248285 Box cdomain = m_geom.Domain ();
249286 cdomain.setBig (0 ,cdomain.length (0 )/2 );
250287 auto cba = amrex::decompose (cdomain, ParallelContext::NProcsSub (),
@@ -283,18 +320,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
283320 for (int k=0 ; k < nz; k++) {
284321 if (k==0 ) {
285322 ald (i,j,k) = 0 .;
286- cud (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
323+ cud (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
287324 bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
288325 } else if (k == nz-1 ) {
289- ald (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
326+ ald (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
290327 cud (i,j,k) = 0 .;
291328 bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
292329 if (i == 0 && j == 0 ) {
293330 bd (i,j,k) *= 2.0 ;
294331 }
295332 } else {
296- ald (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
297- cud (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
333+ ald (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
334+ cud (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
298335 bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
299336 }
300337 }
@@ -339,18 +376,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
339376 for (int k=0 ; k < nz; k++) {
340377 if (k==0 ) {
341378 ald[k] = 0 .;
342- cud[k] = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
379+ cud[k] = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
343380 bd[k] = k2 -ald[k]-cud[k];
344381 } else if (k == nz-1 ) {
345- ald[k] = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
382+ ald[k] = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
346383 cud[k] = 0 .;
347384 bd[k] = k2 -ald[k]-cud[k];
348385 if (i == 0 && j == 0 ) {
349386 bd[k] *= 2.0 ;
350387 }
351388 } else {
352- ald[k] = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
353- cud[k] = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
389+ ald[k] = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
390+ cud[k] = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
354391 bd[k] = k2 -ald[k]-cud[k];
355392 }
356393 }
0 commit comments