3434
3535#include < algorithm>
3636#include < complex>
37+ #include < limits>
3738#include < memory>
3839#include < utility>
3940#include < variant>
@@ -43,6 +44,8 @@ namespace amrex::FFT
4344
4445enum struct Direction { forward, backward, both, none };
4546
47+ enum struct DomainStrategy { slab, pencil };
48+
4649AMREX_ENUM ( Boundary, periodic, even, odd );
4750
4851enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
@@ -55,7 +58,11 @@ struct Info
5558 // ! batch size.
5659 bool batch_mode = false ;
5760
61+ // ! Max number of processes to use
62+ int nprocs = std::numeric_limits<int >::max();
63+
5864 Info& setBatchMode (bool x) { batch_mode = x; return *this ; }
65+ Info& setNumProcs (int n) { nprocs = n; return *this ; }
5966};
6067
6168#ifdef AMREX_USE_HIP
@@ -172,18 +179,34 @@ struct Plan
172179 }
173180
174181 template <Direction D>
175- void init_r2c (Box const & box, T* pr, VendorComplex* pc)
182+ void init_r2c (Box const & box, T* pr, VendorComplex* pc, bool is_2d_transform = false )
176183 {
177184 static_assert (D == Direction::forward || D == Direction::backward);
178185
186+ int rank = is_2d_transform ? 2 : 1 ;
187+
179188 kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
180189 defined = true ;
181190 pf = (void *)pr;
182191 pb = (void *)pc;
183192
184- n = box.length (0 );
185- int nc = (n/2 ) + 1 ;
186- howmany = AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ));
193+ int len[2 ] = {};
194+ if (rank == 1 ) {
195+ len[0 ] = box.length (0 );
196+ len[1 ] = box.length (0 ); // Not used except for HIP. Yes it's `(0)`.
197+ } else {
198+ len[0 ] = box.length (1 ); // Most FFT libraries assume row-major ordering
199+ len[1 ] = box.length (0 ); // except for rocfft
200+ }
201+ int nr = (rank == 1 ) ? len[0 ] : len[0 ]*len[1 ];
202+ n = nr;
203+ int nc = (rank == 1 ) ? (len[0 ]/2 +1 ) : (len[1 ]/2 +1 )*len[0 ];
204+ #if (AMREX_SPACEDIM == 1)
205+ howmany = 1 ;
206+ #else
207+ howmany = (rank == 1 ) ? AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ))
208+ : AMREX_D_TERM (1 , *1 , *box.length (2 ));
209+ #endif
187210
188211 amrex::ignore_unused (nc);
189212
@@ -193,43 +216,52 @@ struct Plan
193216 if constexpr (D == Direction::forward) {
194217 cufftType fwd_type = std::is_same_v<float ,T> ? CUFFT_R2C : CUFFT_D2Z;
195218 AMREX_CUFFT_SAFE_CALL
196- (cufftMakePlanMany (plan, 1 , &n , nullptr , 1 , n , nullptr , 1 , nc, fwd_type, howmany, &work_size));
219+ (cufftMakePlanMany (plan, rank, len , nullptr , 1 , nr , nullptr , 1 , nc, fwd_type, howmany, &work_size));
197220 AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
198221 } else {
199222 cufftType bwd_type = std::is_same_v<float ,T> ? CUFFT_C2R : CUFFT_Z2D;
200223 AMREX_CUFFT_SAFE_CALL
201- (cufftMakePlanMany (plan, 1 , &n , nullptr , 1 , nc, nullptr , 1 , n , bwd_type, howmany, &work_size));
224+ (cufftMakePlanMany (plan, rank, len , nullptr , 1 , nc, nullptr , 1 , nr , bwd_type, howmany, &work_size));
202225 AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
203226 }
204227#elif defined(AMREX_USE_HIP)
205228
206229 auto prec = std::is_same_v<float ,T> ? rocfft_precision_single : rocfft_precision_double;
207- const std::size_t length = n;
230+ // switch to column-major ordering
231+ std::size_t length[2 ] = {std::size_t (len[1 ]), std::size_t (len[0 ])};
208232 if constexpr (D == Direction::forward) {
209233 AMREX_ROCFFT_SAFE_CALL
210234 (rocfft_plan_create (&plan, rocfft_placement_notinplace,
211- rocfft_transform_type_real_forward, prec, 1 ,
212- & length, howmany, nullptr ));
235+ rocfft_transform_type_real_forward, prec, rank ,
236+ length, howmany, nullptr ));
213237 } else {
214238 AMREX_ROCFFT_SAFE_CALL
215239 (rocfft_plan_create (&plan, rocfft_placement_notinplace,
216- rocfft_transform_type_real_inverse, prec, 1 ,
217- & length, howmany, nullptr ));
240+ rocfft_transform_type_real_inverse, prec, rank ,
241+ length, howmany, nullptr ));
218242 }
219243
220244#elif defined(AMREX_USE_SYCL)
221245
222- auto * pp = new mkl_desc_r (n);
246+ mkl_desc_r* pp;
247+ if (rank == 1 ) {
248+ pp = new mkl_desc_r (len[0 ]);
249+ } else {
250+ pp = new mkl_desc_r ({std::int64_t (len[0 ]), std::int64_t (len[1 ])});
251+ }
223252#ifndef AMREX_USE_MKL_DFTI_2024
224253 pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT,
225254 oneapi::mkl::dft::config_value::NOT_INPLACE);
226255#else
227256 pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
228257#endif
229258 pp->set_value (oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
230- pp->set_value (oneapi::mkl::dft::config_param::FWD_DISTANCE, n );
259+ pp->set_value (oneapi::mkl::dft::config_param::FWD_DISTANCE, nr );
231260 pp->set_value (oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
232- std::vector<std::int64_t > strides = {0 ,1 };
261+ std::vector<std::int64_t > strides;
262+ strides.push_back (0 );
263+ if (rank == 2 ) { strides.push_back (len[1 ]); }
264+ strides.push_back (1 );
233265#ifndef AMREX_USE_MKL_DFTI_2024
234266 pp->set_value (oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
235267 pp->set_value (oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
@@ -247,21 +279,21 @@ struct Plan
247279 if constexpr (std::is_same_v<float ,T>) {
248280 if constexpr (D == Direction::forward) {
249281 plan = fftwf_plan_many_dft_r2c
250- (1 , &n , howmany, pr, nullptr , 1 , n , pc, nullptr , 1 , nc,
282+ (rank, len , howmany, pr, nullptr , 1 , nr , pc, nullptr , 1 , nc,
251283 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
252284 } else {
253285 plan = fftwf_plan_many_dft_c2r
254- (1 , &n , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , n ,
286+ (rank, len , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , nr ,
255287 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
256288 }
257289 } else {
258290 if constexpr (D == Direction::forward) {
259291 plan = fftw_plan_many_dft_r2c
260- (1 , &n , howmany, pr, nullptr , 1 , n , pc, nullptr , 1 , nc,
292+ (rank, len , howmany, pr, nullptr , 1 , nr , pc, nullptr , 1 , nc,
261293 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
262294 } else {
263295 plan = fftw_plan_many_dft_c2r
264- (1 , &n , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , n ,
296+ (rank, len , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , nr ,
265297 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
266298 }
267299 }
@@ -1087,13 +1119,17 @@ namespace detail
10871119 template <typename FA1, typename FA2>
10881120 std::unique_ptr<char ,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
10891121 {
1122+ bool not_same_fa = true ;
1123+ if constexpr (std::is_same_v<FA1,FA2>) {
1124+ not_same_fa = (&fa1 != &fa2);
1125+ }
10901126 using FAB1 = typename FA1::FABType::value_type;
10911127 using FAB2 = typename FA2::FABType::value_type;
10921128 using T1 = typename FAB1::value_type;
10931129 using T2 = typename FAB2::value_type;
10941130 auto myproc = ParallelContext::MyProcSub ();
10951131 bool alloc_1 = (myproc < fa1.size ());
1096- bool alloc_2 = (myproc < fa2.size ());
1132+ bool alloc_2 = (myproc < fa2.size ()) && not_same_fa ;
10971133 void * p = nullptr ;
10981134 if (alloc_1 && alloc_2) {
10991135 Box const & box1 = fa1.fabbox (myproc);
0 commit comments