@@ -91,6 +91,7 @@ template <typename MF = MultiFab>
9191class PoissonHybrid
9292{
9393public:
94+ using T = typename MF::value_type;
9495
9596 template <typename FA=MF, std::enable_if_t <IsFabArray_v<FA>,int > = 0 >
9697 explicit PoissonHybrid (Geometry const & geom)
@@ -104,6 +105,11 @@ public:
104105 }
105106
106107 void solve (MF& soln, MF const & rhs);
108+ void solve (MF& soln, MF const & rhs, Vector<T> const & dz);
109+ void solve (MF& soln, MF const & rhs, Gpu::DeviceVector<T> const & dz);
110+
111+ template <typename DZ>
112+ void solve_doit (MF& soln, MF const & rhs, DZ const & dz); // has to be public for cuda
107113
108114private:
109115 Geometry m_geom;
@@ -225,14 +231,44 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
225231
226232template <typename MF>
227233void PoissonHybrid<MF>::solve (MF& soln, MF const & rhs)
234+ {
235+ auto delz = T (m_geom.CellSize (AMREX_SPACEDIM-1 ));
236+ struct DZ {
237+ [[nodiscard]] constexpr T operator [] (int ) const { return m_delz; }
238+ T m_delz;
239+ };
240+ solve_doit (soln, rhs, DZ{delz});
241+ }
242+
243+ template <typename MF>
244+ void PoissonHybrid<MF>::solve (MF& soln, MF const & rhs, Gpu::DeviceVector<T> const & dz)
245+ {
246+ auto const * pdz = dz.dataPtr ();
247+ solve_doit (soln, rhs, pdz);
248+ }
249+
250+ template <typename MF>
251+ void PoissonHybrid<MF>::solve (MF& soln, MF const & rhs, Vector<T> const & dz)
252+ {
253+ #ifdef AMREX_USE_GPU
254+ Gpu::DeviceVector<T> d_dz (dz.size ());
255+ Gpu::htod_memcpy_async (d_dz.data (), dz.data (), dz.size ()*sizeof (T));
256+ auto const * pdz = d_dz.data ();
257+ #else
258+ auto const * pdz = dz.data ();
259+ #endif
260+ solve_doit (soln, rhs, pdz);
261+ }
262+
263+ template <typename MF>
264+ template <typename DZ>
265+ void PoissonHybrid<MF>::solve_doit (MF& soln, MF const & rhs, DZ const & dz)
228266{
229267 BL_PROFILE (" FFT::PoissonHybrid::solve" );
230268
231269#if (AMREX_SPACEDIM < 3)
232- amrex::ignore_unused (soln, rhs);
270+ amrex::ignore_unused (soln, rhs, dz );
233271#else
234- using T = typename MF::value_type;
235-
236272 auto facx = T (2 )*Math::pi<T>()/T (m_geom.ProbLength (0 ));
237273 auto facy = T (2 )*Math::pi<T>()/T (m_geom.ProbLength (1 ));
238274 auto dx = T (m_geom.CellSize (0 ));
@@ -242,9 +278,6 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
242278 auto ny = m_geom.Domain ().length (1 );
243279 auto nz = m_geom.Domain ().length (2 );
244280
245- Gpu::DeviceVector<T> delzv (nz, T (m_geom.CellSize (2 )));
246- auto const * delz = delzv.data ();
247-
248281 Box cdomain = m_geom.Domain ();
249282 cdomain.setBig (0 ,cdomain.length (0 )/2 );
250283 auto cba = amrex::decompose (cdomain, ParallelContext::NProcsSub (),
@@ -283,18 +316,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
283316 for (int k=0 ; k < nz; k++) {
284317 if (k==0 ) {
285318 ald (i,j,k) = 0 .;
286- cud (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
319+ cud (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
287320 bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
288321 } else if (k == nz-1 ) {
289- ald (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
322+ ald (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
290323 cud (i,j,k) = 0 .;
291324 bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
292325 if (i == 0 && j == 0 ) {
293326 bd (i,j,k) *= 2.0 ;
294327 }
295328 } else {
296- ald (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
297- cud (i,j,k) = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
329+ ald (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
330+ cud (i,j,k) = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
298331 bd (i,j,k) = k2 -ald (i,j,k)-cud (i,j,k);
299332 }
300333 }
@@ -339,18 +372,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
339372 for (int k=0 ; k < nz; k++) {
340373 if (k==0 ) {
341374 ald[k] = 0 .;
342- cud[k] = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
375+ cud[k] = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
343376 bd[k] = k2 -ald[k]-cud[k];
344377 } else if (k == nz-1 ) {
345- ald[k] = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
378+ ald[k] = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
346379 cud[k] = 0 .;
347380 bd[k] = k2 -ald[k]-cud[k];
348381 if (i == 0 && j == 0 ) {
349382 bd[k] *= 2.0 ;
350383 }
351384 } else {
352- ald[k] = 2.0 /(delz [k]*(delz [k]+delz [k-1 ]));
353- cud[k] = 2.0 /(delz [k]*(delz [k]+delz [k+1 ]));
385+ ald[k] = 2.0 /(dz [k]*(dz [k]+dz [k-1 ]));
386+ cud[k] = 2.0 /(dz [k]*(dz [k]+dz [k+1 ]));
354387 bd[k] = k2 -ald[k]-cud[k];
355388 }
356389 }
0 commit comments