Skip to content

Commit bca9c1b

Browse files
Apply suggestions from code review
Co-authored-by: Alexander Sinn <[email protected]>
1 parent a2c2201 commit bca9c1b

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

Src/Base/AMReX_GpuLaunch.H

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ namespace Gpu {
180180

181181
struct ExecConfig
182182
{
183-
Long ntotalthreads;
183+
Long start_idx;
184184
int nblocks;
185185
};
186186

@@ -195,6 +195,7 @@ namespace Gpu {
195195
// loops inside GPU kernels.
196196
auto nlaunches = int((N+nmax-1)/nmax);
197197
Vector<ExecConfig> r(nlaunches);
198+
Long ndone = 0;
198199
for (int i = 0; i < nlaunches; ++i) {
199200
int nblocks;
200201
if (N > nmax) {
@@ -203,8 +204,9 @@ namespace Gpu {
203204
} else {
204205
nblocks = int((N+MT-1)/MT);
205206
}
206-
// Total # of threads in this launch
207-
r[i].ntotalthreads = Long(nblocks) * MT;
207+
// At which element ID the kernel should start
208+
r[i].start_idx = ndone;
209+
ndone += Long(nblocks) * MT;
208210
// # of blocks in this launch
209211
r[i].nblocks = nblocks;
210212
}

Src/Base/AMReX_GpuLaunchFunctsG.H

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -800,22 +800,19 @@ ParallelFor (Gpu::KernelInfo const&, T n, L const& f) noexcept
800800
static_assert(sizeof(T) >= 2);
801801
if (amrex::isEmpty(n)) { return; }
802802
const auto& nec = Gpu::makeNExecutionConfigs<MT>(n);
803-
T ndone = 0;
804803
for (auto const& ec : nec) {
805-
T nleft = n - ndone;
804+
const T start_idx = T(ec.start_idx);
805+
const T nleft = n - start_idx;
806806
AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
807807
[=] AMREX_GPU_DEVICE () noexcept {
808808
// This will not overflow, even though nblocks*MT might.
809809
auto tid = T(MT)*T(blockIdx.x)+T(threadIdx.x);
810810
if (tid < nleft) {
811-
detail::call_f_scalar_handler(f, tid+ndone,
811+
detail::call_f_scalar_handler(f, tid+start_idx,
812812
Gpu::Handler(amrex::min((std::uint64_t(nleft-tid)+(std::uint64_t)threadIdx.x),
813813
(std::uint64_t)blockDim.x)));
814814
}
815815
});
816-
if (Long(nleft) > ec.ntotalthreads) {
817-
ndone += ec.ntotalthreads;
818-
}
819816
}
820817
AMREX_GPU_ERROR_CHECK();
821818
}

0 commit comments

Comments
 (0)