@@ -43,6 +43,8 @@ namespace amrex::FFT
4343
4444enum struct Direction { forward, backward, both, none };
4545
46+ enum struct DomainStrategy { slab, pencil }; // NOTE(review): presumably selects the domain-decomposition layout (slab vs. pencil) used for the transform -- confirm against the code that consumes this enum
47+
4648AMREX_ENUM ( Boundary, periodic, even, odd );
4749
4850enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
@@ -172,18 +174,29 @@ struct Plan
172174 }
173175
174176 template <Direction D>
175- void init_r2c (Box const & box, T* pr, VendorComplex* pc)
177+ void init_r2c (Box const & box, T* pr, VendorComplex* pc, bool is_2d_transform = false )
176178 {
177179 static_assert (D == Direction::forward || D == Direction::backward);
178180
181+ int rank = is_2d_transform ? 2 : 1 ;
182+
179183 kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
180184 defined = true ;
181185 pf = (void *)pr;
182186 pb = (void *)pc;
183187
184- n = box.length (0 );
185- int nc = (n/2 ) + 1 ;
186- howmany = AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ));
188+ int len[2 ] = {};
189+ if (rank == 1 ) {
190+ len[0 ] = box.length (0 );
191+ } else {
192+ len[0 ] = box.length (1 );
193+ len[1 ] = box.length (0 );
194+ }
195+ int nr = (rank == 1 ) ? len[0 ] : len[0 ]*len[1 ];
196+ n = nr;
197+ int nc = (rank == 1 ) ? (len[0 ]/2 +1 ) : (len[1 ]/2 +1 )*len[0 ];
198+ howmany = (rank == 1 ) ? AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ))
199+ : AMREX_D_TERM (1 , *1 , *box.length (2 ));
187200
188201 amrex::ignore_unused (nc);
189202
@@ -193,43 +206,51 @@ struct Plan
193206 if constexpr (D == Direction::forward) {
194207 cufftType fwd_type = std::is_same_v<float ,T> ? CUFFT_R2C : CUFFT_D2Z;
195208 AMREX_CUFFT_SAFE_CALL
196- (cufftMakePlanMany (plan, 1 , &n , nullptr , 1 , n , nullptr , 1 , nc, fwd_type, howmany, &work_size));
209+ (cufftMakePlanMany (plan, rank, len , nullptr , 1 , nr , nullptr , 1 , nc, fwd_type, howmany, &work_size));
197210 AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
198211 } else {
199212 cufftType bwd_type = std::is_same_v<float ,T> ? CUFFT_C2R : CUFFT_Z2D;
200213 AMREX_CUFFT_SAFE_CALL
201- (cufftMakePlanMany (plan, 1 , &n , nullptr , 1 , nc, nullptr , 1 , n , bwd_type, howmany, &work_size));
214+ (cufftMakePlanMany (plan, rank, len , nullptr , 1 , nc, nullptr , 1 , nr , bwd_type, howmany, &work_size));
202215 AMREX_CUFFT_SAFE_CALL (cufftSetStream (plan, Gpu::gpuStream ()));
203216 }
204217#elif defined(AMREX_USE_HIP)
205218
206219 auto prec = std::is_same_v<float ,T> ? rocfft_precision_single : rocfft_precision_double;
207- const std::size_t length = n ;
220+ const std::size_t length[ 2 ] = { std::size_t (len[ 0 ]), std::size_t (len[ 1 ])} ;
208221 if constexpr (D == Direction::forward) {
209222 AMREX_ROCFFT_SAFE_CALL
210223 (rocfft_plan_create (&plan, rocfft_placement_notinplace,
211- rocfft_transform_type_real_forward, prec, 1 ,
212- & length, howmany, nullptr ));
224+ rocfft_transform_type_real_forward, prec, rank ,
225+ length, howmany, nullptr ));
213226 } else {
214227 AMREX_ROCFFT_SAFE_CALL
215228 (rocfft_plan_create (&plan, rocfft_placement_notinplace,
216- rocfft_transform_type_real_inverse, prec, 1 ,
217- & length, howmany, nullptr ));
229+ rocfft_transform_type_real_inverse, prec, rank ,
230+ length, howmany, nullptr ));
218231 }
219232
220233#elif defined(AMREX_USE_SYCL)
221234
222- auto * pp = new mkl_desc_r (n);
235+ mkl_desc_r* pp; // real-domain descriptor: must match the mkl_desc_r objects allocated in both branches below
236+ if (rank == 1 ) {
237+ pp = new mkl_desc_r (len[0 ]); // 1D transform of length len[0]
238+ } else {
239+ pp = new mkl_desc_r ({std::int64_t (len[0 ]), std::int64_t (len[1 ])}); // 2D transform, dimensions {len[0], len[1]}
240+ }
223241#ifndef AMREX_USE_MKL_DFTI_2024
224242 pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT,
225243 oneapi::mkl::dft::config_value::NOT_INPLACE);
226244#else
227245 pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
228246#endif
229247 pp->set_value (oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
230- pp->set_value (oneapi::mkl::dft::config_param::FWD_DISTANCE, n );
248+ pp->set_value (oneapi::mkl::dft::config_param::FWD_DISTANCE, nr );
231249 pp->set_value (oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
232- std::vector<std::int64_t > strides = {0 ,1 };
250+ std::vector<std::int64_t > strides; // stride vector passed to FWD_STRIDES / BWD_STRIDES: leading offset followed by one stride per dimension
251+ strides.push_back (0 ); // leading entry is the offset of the first element
252+ if (rank == 2 ) { strides.push_back (len[1 ]); } // 2D only: the first dimension advances by len[1] elements
253+ strides.push_back (1 ); // last dimension is contiguous; this entry belongs in `strides` (`rank` is a plain int)
233254#ifndef AMREX_USE_MKL_DFTI_2024
234255 pp->set_value (oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
235256 pp->set_value (oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
@@ -247,21 +268,21 @@ struct Plan
247268 if constexpr (std::is_same_v<float ,T>) {
248269 if constexpr (D == Direction::forward) {
249270 plan = fftwf_plan_many_dft_r2c
250- (1 , &n , howmany, pr, nullptr , 1 , n , pc, nullptr , 1 , nc,
271+ (rank, len , howmany, pr, nullptr , 1 , nr , pc, nullptr , 1 , nc,
251272 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
252273 } else {
253274 plan = fftwf_plan_many_dft_c2r
254- (1 , &n , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , n ,
275+ (rank, len , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , nr ,
255276 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
256277 }
257278 } else {
258279 if constexpr (D == Direction::forward) {
259280 plan = fftw_plan_many_dft_r2c
260- (1 , &n , howmany, pr, nullptr , 1 , n , pc, nullptr , 1 , nc,
281+ (rank, len , howmany, pr, nullptr , 1 , nr , pc, nullptr , 1 , nc,
261282 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
262283 } else {
263284 plan = fftw_plan_many_dft_c2r
264- (1 , &n , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , n ,
285+ (rank, len , howmany, pc, nullptr , 1 , nc, pr, nullptr , 1 , nr ,
265286 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
266287 }
267288 }
@@ -1087,13 +1108,17 @@ namespace detail
10871108 template <typename FA1, typename FA2>
10881109 std::unique_ptr<char ,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
10891110 {
1111+ bool not_same_fa = true ;
1112+ if constexpr (std::is_same_v<FA1,FA2>) {
1113+ not_same_fa = (&fa1 != &fa2);
1114+ }
10901115 using FAB1 = typename FA1::FABType::value_type;
10911116 using FAB2 = typename FA2::FABType::value_type;
10921117 using T1 = typename FAB1::value_type;
10931118 using T2 = typename FAB2::value_type;
10941119 auto myproc = ParallelContext::MyProcSub ();
10951120 bool alloc_1 = (myproc < fa1.size ());
1096- bool alloc_2 = (myproc < fa2.size ());
1121+ bool alloc_2 = (myproc < fa2.size ()) && not_same_fa ;
10971122 void * p = nullptr ;
10981123 if (alloc_1 && alloc_2) {
10991124 Box const & box1 = fa1.fabbox (myproc);
0 commit comments