@@ -145,9 +145,6 @@ struct Plan
145145 VendorPlan plan2{};
146146 void * pf = nullptr ;
147147 void * pb = nullptr ;
148- #ifdef AMREX_USE_CUDA
149- std::size_t work_size = 0 ;
150- #endif
151148
152149#ifdef AMREX_USE_GPU
153150 void set_ptrs (void * p0, void * p1) {
@@ -206,6 +203,7 @@ struct Plan
206203
207204 AMREX_CUFFT_SAFE_CALL (cufftCreate (&plan));
208205 AMREX_CUFFT_SAFE_CALL (cufftSetAutoAllocation (plan, 0 ));
206+ std::size_t work_size;
209207 if constexpr (D == Direction::forward) {
210208 cufftType fwd_type = std::is_same_v<float ,T> ? CUFFT_R2C : CUFFT_D2Z;
211209 AMREX_CUFFT_SAFE_CALL
@@ -215,7 +213,6 @@ struct Plan
215213 AMREX_CUFFT_SAFE_CALL
216214 (cufftMakePlanMany (plan, rank, len, nullptr , 1 , nc, nullptr , 1 , nr, bwd_type, howmany, &work_size));
217215 }
218- AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
219216
220217#elif defined(AMREX_USE_HIP)
221218
@@ -314,9 +311,9 @@ struct Plan
314311 AMREX_CUFFT_SAFE_CALL (cufftSetAutoAllocation (plan, 0 ));
315312
316313 cufftType t = std::is_same_v<float ,T> ? CUFFT_C2C : CUFFT_Z2Z;
314+ std::size_t work_size;
317315 AMREX_CUFFT_SAFE_CALL
318316 (cufftMakePlanMany (plan, 1 , &n, nullptr , 1 , n, nullptr , 1 , n, t, howmany, &work_size));
319- AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
320317
321318#elif defined(AMREX_USE_HIP)
322319
@@ -475,9 +472,9 @@ struct Plan
475472 AMREX_CUFFT_SAFE_CALL (cufftCreate (&plan));
476473 AMREX_CUFFT_SAFE_CALL (cufftSetAutoAllocation (plan, 0 ));
477474 cufftType fwd_type = std::is_same_v<float ,T> ? CUFFT_R2C : CUFFT_D2Z;
475+ std::size_t work_size;
478476 AMREX_CUFFT_SAFE_CALL
479477 (cufftMakePlanMany (plan, 1 , &nex, nullptr , 1 , nc*2 , nullptr , 1 , nc, fwd_type, howmany, &work_size));
480- AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
481478
482479#elif defined(AMREX_USE_HIP)
483480
@@ -585,8 +582,14 @@ struct Plan
585582 auto * po = (TO*)((D == Direction::forward) ? pb : pf);
586583
587584#if defined(AMREX_USE_CUDA)
585+ AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
586+
587+ std::size_t work_size = 0 ;
588+ AMREX_CUFFT_SAFE_CALL (cufftGetSize (plan, &work_size));
589+
588590 auto * work_area = The_Arena ()->alloc (work_size);
589591 AMREX_CUFFT_SAFE_CALL (cufftSetWorkArea (plan, work_area));
592+
590593 if constexpr (D == Direction::forward) {
591594 if constexpr (std::is_same_v<float ,T>) {
592595 AMREX_CUFFT_SAFE_CALL (cufftExecR2C (plan, pi, po));
@@ -625,8 +628,14 @@ struct Plan
625628 auto * p = (VendorComplex*)pf;
626629
627630#if defined(AMREX_USE_CUDA)
631+ AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
632+
633+ std::size_t work_size = 0 ;
634+ AMREX_CUFFT_SAFE_CALL (cufftGetSize (plan, &work_size));
635+
628636 auto * work_area = The_Arena ()->alloc (work_size);
629637 AMREX_CUFFT_SAFE_CALL (cufftSetWorkArea (plan, work_area));
638+
630639 auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
631640 if constexpr (std::is_same_v<float ,T>) {
632641 AMREX_CUFFT_SAFE_CALL (cufftExecC2C (plan, p, p, dir));
@@ -1061,8 +1070,14 @@ struct Plan
10611070
10621071#if defined(AMREX_USE_CUDA)
10631072
1073+ AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
1074+
1075+ std::size_t work_size = 0 ;
1076+ AMREX_CUFFT_SAFE_CALL (cufftGetSize (plan, &work_size));
1077+
10641078 auto * work_area = The_Arena ()->alloc (work_size);
10651079 AMREX_CUFFT_SAFE_CALL (cufftSetWorkArea (plan, work_area));
1080+
10661081 if constexpr (std::is_same_v<float ,T>) {
10671082 AMREX_CUFFT_SAFE_CALL (cufftExecR2C (plan, (T*)pscratch, (VendorComplex*)pscratch));
10681083 } else {
@@ -1165,6 +1180,7 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* buffer, bool cache)
11651180 } else {
11661181 type = std::is_same_v<float ,T> ? CUFFT_C2R : CUFFT_Z2D;
11671182 }
1183+ std::size_t work_size;
11681184 if constexpr (M == 1 ) {
11691185 AMREX_CUFFT_SAFE_CALL
11701186 (cufftMakePlan1d (plan, fft_size[0 ], type, howmany, &work_size));
@@ -1175,7 +1191,6 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* buffer, bool cache)
11751191 AMREX_CUFFT_SAFE_CALL
11761192 (cufftMakePlan3d (plan, fft_size[2 ], fft_size[1 ], fft_size[0 ], type, &work_size));
11771193 }
1178- AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
11791194
11801195#elif defined(AMREX_USE_HIP)
11811196
0 commit comments