Skip to content

Commit eec7d72

Browse files
committed
Fix CUDA cache issue
1 parent ec494bf commit eec7d72

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

Src/FFT/AMReX_FFT_Helper.H

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -145,9 +145,6 @@ struct Plan
145145
VendorPlan plan2{};
146146
void* pf = nullptr;
147147
void* pb = nullptr;
148-
#ifdef AMREX_USE_CUDA
149-
std::size_t work_size = 0;
150-
#endif
151148

152149
#ifdef AMREX_USE_GPU
153150
void set_ptrs (void* p0, void* p1) {
@@ -206,6 +203,7 @@ struct Plan
206203

207204
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
208205
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
206+
std::size_t work_size;
209207
if constexpr (D == Direction::forward) {
210208
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
211209
AMREX_CUFFT_SAFE_CALL
@@ -215,7 +213,6 @@ struct Plan
215213
AMREX_CUFFT_SAFE_CALL
216214
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
217215
}
218-
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
219216

220217
#elif defined(AMREX_USE_HIP)
221218

@@ -314,9 +311,9 @@ struct Plan
314311
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
315312

316313
cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
314+
std::size_t work_size;
317315
AMREX_CUFFT_SAFE_CALL
318316
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
319-
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
320317

321318
#elif defined(AMREX_USE_HIP)
322319

@@ -475,9 +472,9 @@ struct Plan
475472
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
476473
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
477474
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
475+
std::size_t work_size;
478476
AMREX_CUFFT_SAFE_CALL
479477
(cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));
480-
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
481478

482479
#elif defined(AMREX_USE_HIP)
483480

@@ -585,8 +582,14 @@ struct Plan
585582
auto* po = (TO*)((D == Direction::forward) ? pb : pf);
586583

587584
#if defined(AMREX_USE_CUDA)
585+
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
586+
587+
std::size_t work_size = 0;
588+
AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
589+
588590
auto* work_area = The_Arena()->alloc(work_size);
589591
AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
592+
590593
if constexpr (D == Direction::forward) {
591594
if constexpr (std::is_same_v<float,T>) {
592595
AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
@@ -625,8 +628,14 @@ struct Plan
625628
auto* p = (VendorComplex*)pf;
626629

627630
#if defined(AMREX_USE_CUDA)
631+
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
632+
633+
std::size_t work_size = 0;
634+
AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
635+
628636
auto* work_area = The_Arena()->alloc(work_size);
629637
AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
638+
630639
auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
631640
if constexpr (std::is_same_v<float,T>) {
632641
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
@@ -1061,8 +1070,14 @@ struct Plan
10611070

10621071
#if defined(AMREX_USE_CUDA)
10631072

1073+
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
1074+
1075+
std::size_t work_size = 0;
1076+
AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
1077+
10641078
auto* work_area = The_Arena()->alloc(work_size);
10651079
AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
1080+
10661081
if constexpr (std::is_same_v<float,T>) {
10671082
AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
10681083
} else {
@@ -1165,6 +1180,7 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* buffer, bool cache)
11651180
} else {
11661181
type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
11671182
}
1183+
std::size_t work_size;
11681184
if constexpr (M == 1) {
11691185
AMREX_CUFFT_SAFE_CALL
11701186
(cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
@@ -1175,7 +1191,6 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* buffer, bool cache)
11751191
AMREX_CUFFT_SAFE_CALL
11761192
(cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
11771193
}
1178-
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
11791194

11801195
#elif defined(AMREX_USE_HIP)
11811196

Src/FFT/AMReX_FFT_LocalR2C.H

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ private:
8787

8888
void* m_buffer = nullptr; // for fftw
8989

90-
#if defined(AMREX_USE_GPU)
90+
#if defined(AMREX_USE_SYCL)
9191
gpuStream_t m_gpu_stream;
9292
#endif
9393

@@ -106,16 +106,20 @@ LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, bool cache_plan)
106106
BL_PROFILE("FFT::LocalR2C");
107107
m_spectral_size[0] = m_real_size[0]/2 + 1;
108108

109-
#if defined(AMREX_USE_GPU)
109+
#if defined(AMREX_USE_SYCL)
110+
110111
auto current_stream = Gpu::gpuStream();
111112
Gpu::Device::resetStreamIndex();
112113
m_gpu_stream = Gpu::gpuStream();
113-
#else
114+
115+
#elif !defined(AMREX_USE_GPU)
116+
114117
Long total_size = 1;
115118
for (auto s : m_spectral_size) {
116119
total_size *= s;
117120
}
118121
m_buffer = The_Arena()->alloc(sizeof(GpuComplex<T>)*total_size);
122+
119123
#endif
120124

121125
#ifdef AMREX_USE_SYCL
@@ -130,7 +134,7 @@ LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, bool cache_plan)
130134
}
131135
#endif
132136

133-
#if defined(AMREX_USE_GPU)
137+
#if defined(AMREX_USE_SYCL)
134138
Gpu::Device::setStream(current_stream);
135139
#endif
136140
}
@@ -158,13 +162,19 @@ void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
158162
BL_PROFILE("FFT::LocalR2C::forward");
159163

160164
#if defined(AMREX_USE_GPU)
165+
161166
m_fft_fwd.set_ptrs((void*)indata, (void*)outdata);
167+
168+
#if defined(AMREX_USE_SYCL)
162169
auto current_stream = Gpu::gpuStream();
163170
if (current_stream != m_gpu_stream) {
164171
Gpu::streamSynchronize();
165172
Gpu::Device::setStream(m_gpu_stream);
166173
}
167-
#else
174+
#endif
175+
176+
#else /* FFTW */
177+
168178
int ny = 1;
169179
for (int idim = 1; idim < M; ++idim) {
170180
ny *= m_real_size[idim];
@@ -176,18 +186,23 @@ void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
176186
pdst += sizeof(GpuComplex<T>) * m_spectral_size[0];
177187
psrc += sizeof( T ) * m_real_size[0];
178188
}
189+
179190
#endif
180191

181192
m_fft_fwd.template compute_r2c<Direction::forward>();
182193

183-
#if defined(AMREX_USE_GPU)
194+
#if defined(AMREX_USE_SYCL)
195+
184196
if (current_stream != m_gpu_stream) {
185197
Gpu::Device::setStream(current_stream);
186198
}
187-
#else
199+
200+
#elif !defined(AMREX_USE_GPU)
201+
188202
std::size_t nbytes = sizeof(GpuComplex<T>);
189203
for (auto s : m_spectral_size) { nbytes *= s; }
190204
std::memcpy(outdata, m_buffer, nbytes);
205+
191206
#endif
192207
}
193208

@@ -199,25 +214,35 @@ void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
199214
BL_PROFILE("FFT::LocalR2C::backward");
200215

201216
#if defined(AMREX_USE_GPU)
217+
202218
m_fft_bwd.set_ptrs((void*)outdata, (void*)indata);
219+
220+
#if defined(AMREX_USE_SYCL)
203221
auto current_stream = Gpu::gpuStream();
204222
if (current_stream != m_gpu_stream) {
205223
Gpu::streamSynchronize();
206224
Gpu::Device::setStream(m_gpu_stream);
207225
}
208-
#else
226+
#endif
227+
228+
#else /* FFTW */
229+
209230
std::size_t nbytes = sizeof(GpuComplex<T>);
210231
for (auto s : m_spectral_size) { nbytes *= s; }
211232
std::memcpy(m_buffer, indata, nbytes);
233+
212234
#endif
213235

214236
m_fft_bwd.template compute_r2c<Direction::backward>();
215237

216-
#if defined(AMREX_USE_GPU)
238+
#if defined(AMREX_USE_SYCL)
239+
217240
if (current_stream != m_gpu_stream) {
218241
Gpu::Device::setStream(current_stream);
219242
}
220-
#else
243+
244+
#elif !defined(AMREX_USE_GPU)
245+
221246
int ny = 1;
222247
for (int idim = 1; idim < M; ++idim) {
223248
ny *= m_real_size[idim];
@@ -229,6 +254,7 @@ void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
229254
pdst += sizeof( T ) * m_real_size[0];
230255
psrc += sizeof(GpuComplex<T>) * m_spectral_size[0];
231256
}
257+
232258
#endif
233259
}
234260

0 commit comments

Comments
 (0)