
Commit efda45f

Refactor grid-stride loop
Move grid-stride loops out of GPU kernels. @ashesh2512 noticed performance issues with grid-stride loops on AMD GPUs in PelePhysics's large kernels. Thanks to @AlexanderSinn for the suggestion implemented in this PR.
1 parent 62c2a81 commit efda45f
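To see the shape of the change, here is a minimal CUDA-style sketch of the two launch strategies (illustration only, not AMReX code; the kernel names and the doubling body are stand-ins for the user's lambda):

// Before: one launch; each thread strides through the whole range,
// keeping loop state live for the kernel's entire lifetime.
__global__ void before_kernel (long n, float* y)
{
    for (long i = (long)blockDim.x*blockIdx.x + threadIdx.x,
              stride = (long)blockDim.x*gridDim.x;
         i < n; i += stride) {
        y[i] *= 2.0f;   // stand-in for the user's lambda body
    }
}

// After: each thread handles exactly one element and retires; the host
// launches the kernel repeatedly, advancing `offset`, until n is covered.
__global__ void after_kernel (long n, long offset, float* y)
{
    long i = (long)blockDim.x*blockIdx.x + threadIdx.x + offset;
    if (i < n) {
        y[i] *= 2.0f;   // stand-in for the user's lambda body
    }
}

The diff below implements the "after" strategy: the device-side loop becomes a bounds check, and a new helper computes how many fixed-size launches are needed.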

File tree

2 files changed: +127 -39 lines

Src/Base/AMReX_GpuLaunch.H
Src/Base/AMReX_GpuLaunchFunctsG.H

Src/Base/AMReX_GpuLaunch.H

Lines changed: 40 additions & 0 deletions

@@ -21,6 +21,7 @@
 #include <AMReX_RandomEngine.H>
 #include <AMReX_Algorithm.H>
 #include <AMReX_Math.H>
+#include <AMReX_Vector.H>
 #include <cstddef>
 #include <limits>
 #include <algorithm>
@@ -176,6 +177,45 @@ namespace Gpu {
     {
         return makeExecutionConfig<MT>(box.numPts());
     }
+
+    struct ExecConfig
+    {
+        Long ntotalthreads;
+        int nblocks;
+    };
+
+    template <int MT>
+    Vector<ExecConfig> makeNExecutionConfigs (Long N) noexcept
+    {
+        // Max # of blocks in a kernel launch
+        int numblocks_max = std::numeric_limits<int>::max();
+        // Max # of threads in a kernel launch
+        Long nmax = Long(MT) * numblocks_max;
+        // # of launches needed for N elements without using grid-stride
+        // loops inside GPU kernels.
+        auto nlaunches = int((N+nmax-1)/nmax);
+        Vector<ExecConfig> r(nlaunches);
+        for (int i = 0; i < nlaunches; ++i) {
+            int nblocks;
+            if (N > nmax) {
+                nblocks = numblocks_max;
+                N -= nmax;
+            } else {
+                nblocks = int((N+MT-1)/MT);
+            }
+            // Total # of threads in this launch
+            r[i].ntotalthreads = Long(nblocks) * MT;
+            // # of blocks in this launch
+            r[i].nblocks = nblocks;
+        }
+        return r;
+    }
+
+    template <int MT, int dim>
+    Vector<ExecConfig> makeNExecutionConfigs (BoxND<dim> const& box) noexcept
+    {
+        return makeNExecutionConfigs<MT>(box.numPts());
+    }
 #endif
 
 }
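To make the arithmetic in makeNExecutionConfigs concrete, the standalone sketch below reproduces the same partition logic. The real code caps each launch at std::numeric_limits<int>::max() blocks, which only splits truly enormous ranges; the hypothetical cap of 4 blocks here just makes the splitting visible at a small N:

#include <cstdio>
#include <vector>

struct ExecConfig { long long ntotalthreads; int nblocks; };

// Same partition logic as makeNExecutionConfigs above, with the block cap
// passed in as a parameter for illustration.
std::vector<ExecConfig> partition (long long N, int MT, int numblocks_max)
{
    long long nmax = (long long)MT * numblocks_max; // max threads per launch
    int nlaunches = int((N + nmax - 1) / nmax);     // ceiling division
    std::vector<ExecConfig> r(nlaunches);
    for (int i = 0; i < nlaunches; ++i) {
        int nblocks;
        if (N > nmax) { nblocks = numblocks_max; N -= nmax; }
        else          { nblocks = int((N + MT - 1) / MT); }
        r[i] = { (long long)nblocks * MT, nblocks };
    }
    return r;
}

int main ()
{
    // 2500 items, 256 threads/block, at most 4 blocks (1024 threads) per launch:
    for (auto const& ec : partition(2500, 256, 4)) {
        std::printf("nblocks=%d ntotalthreads=%lld\n", ec.nblocks, ec.ntotalthreads);
    }
    // Prints 4/1024, 4/1024, 2/512 -> three launches covering all 2500 items.
}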

Src/Base/AMReX_GpuLaunchFunctsG.H

Lines changed: 87 additions & 39 deletions

@@ -747,35 +747,76 @@ void launch (int nblocks, int nthreads_per_block, gpuStream_t stream, L&& f) noexcept
     launch(nblocks, nthreads_per_block, 0, stream, std::forward<L>(f));
 }
 
-template<int MT, typename T, typename L>
+template<int MT, typename T, typename L, std::enable_if_t<std::is_integral_v<T>,int> FOO = 0>
 void launch (T const& n, L const& f) noexcept
 {
+    static_assert(sizeof(T) >= 2);
     if (amrex::isEmpty(n)) { return; }
-    const auto ec = Gpu::makeExecutionConfig<MT>(n);
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (auto const i : Gpu::Range(n)) {
-            f(i);
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(n);
+    T ndone = 0;
+    for (auto const& ec : nec) {
+        T nleft = n - ndone;
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            // This will not overflow, even though nblocks*MT might.
+            auto tid = T(MT)*T(blockIdx.x)+T(threadIdx.x);
+            if (tid < nleft) {
+                f(tid+ndone);
+            }
+        });
+        if (nleft > ec.ntotalthreads) {
+            ndone += T(ec.ntotalthreads);
         }
-    });
+    }
     AMREX_GPU_ERROR_CHECK();
 }
+
+template<int MT, int dim, typename L>
+void launch (BoxND<dim> const& box, L const& f) noexcept
+{
+    if (box.isEmpty()) { return; }
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(box);
+    const BoxIndexerND<dim> indexer(box);
+    const auto type = box.ixType();
+    std::uint64_t ndone = 0;
+    for (auto const& ec : nec) {
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            auto icell = std::uint64_t(MT)*blockIdx.x+threadIdx.x + ndone;
+            if (icell < indexer.numPts()) {
+                auto iv = indexer.intVect(icell);
+                f(BoxND<dim>(iv,iv,type));
+            }
+        });
+        ndone += ec.ntotalthreads;
+    }
+    AMREX_GPU_ERROR_CHECK();
+}
 
 template <int MT, typename T, typename L, typename M=std::enable_if_t<std::is_integral<T>::value> >
 std::enable_if_t<MaybeDeviceRunnable<L>::value>
 ParallelFor (Gpu::KernelInfo const&, T n, L const& f) noexcept
 {
+    static_assert(sizeof(T) >= 2);
     if (amrex::isEmpty(n)) { return; }
-    const auto ec = Gpu::makeExecutionConfig<MT>(n);
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (Long i = Long(blockDim.x)*blockIdx.x+threadIdx.x, stride = Long(blockDim.x)*gridDim.x;
-             i < Long(n); i += stride) {
-            detail::call_f_scalar_handler(f, T(i),
-                                          Gpu::Handler(amrex::min((std::uint64_t(n)-i+(std::uint64_t)threadIdx.x),
-                                                                  (std::uint64_t)blockDim.x)));
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(n);
+    T ndone = 0;
+    for (auto const& ec : nec) {
+        T nleft = n - ndone;
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            // This will not overflow, even though nblocks*MT might.
+            auto tid = T(MT)*T(blockIdx.x)+T(threadIdx.x);
+            if (tid < nleft) {
+                detail::call_f_scalar_handler(f, tid+ndone,
+                                              Gpu::Handler(amrex::min((std::uint64_t(nleft)-tid+(std::uint64_t)threadIdx.x),
+                                                                      (std::uint64_t)blockDim.x)));
+            }
+        });
+        if (nleft > ec.ntotalthreads) {
+            ndone += ec.ntotalthreads;
         }
-    });
+    }
     AMREX_GPU_ERROR_CHECK();
 }
 
@@ -785,18 +826,21 @@ ParallelFor (Gpu::KernelInfo const&, BoxND<dim> const& box, L const& f) noexcept
 {
     if (amrex::isEmpty(box)) { return; }
     const BoxIndexerND<dim> indexer(box);
-    const auto ec = Gpu::makeExecutionConfig<MT>(box.numPts());
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (std::uint64_t icell = std::uint64_t(blockDim.x)*blockIdx.x+threadIdx.x, stride = std::uint64_t(blockDim.x)*gridDim.x;
-             icell < indexer.numPts(); icell += stride)
-        {
-            auto iv = indexer.intVect(icell);
-            detail::call_f_intvect_handler(f, iv,
-                                           Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
-                                                                   (std::uint64_t)blockDim.x)));
-        }
-    });
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(box);
+    std::uint64_t ndone = 0;
+    for (auto const& ec : nec) {
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            auto icell = std::uint64_t(MT)*blockIdx.x+threadIdx.x + ndone;
+            if (icell < indexer.numPts()) {
+                auto iv = indexer.intVect(icell);
+                detail::call_f_intvect_handler(f, iv,
+                                               Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
+                                                                       (std::uint64_t)blockDim.x)));
+            }
+        });
+        ndone += ec.ntotalthreads;
+    }
     AMREX_GPU_ERROR_CHECK();
 }
 
@@ -806,17 +850,21 @@ ParallelFor (Gpu::KernelInfo const&, BoxND<dim> const& box, T ncomp, L const& f)
 {
     if (amrex::isEmpty(box)) { return; }
     const BoxIndexerND<dim> indexer(box);
-    const auto ec = Gpu::makeExecutionConfig<MT>(box.numPts());
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (std::uint64_t icell = std::uint64_t(blockDim.x)*blockIdx.x+threadIdx.x, stride = std::uint64_t(blockDim.x)*gridDim.x;
-             icell < indexer.numPts(); icell += stride) {
-            auto iv = indexer.intVect(icell);
-            detail::call_f_intvect_ncomp_handler(f, iv, ncomp,
-                                                 Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
-                                                                         (std::uint64_t)blockDim.x)));
-        }
-    });
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(box);
+    std::uint64_t ndone = 0;
+    for (auto const& ec : nec) {
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            auto icell = std::uint64_t(MT)*blockIdx.x+threadIdx.x + ndone;
+            if (icell < indexer.numPts()) {
+                auto iv = indexer.intVect(icell);
+                detail::call_f_intvect_ncomp_handler(f, iv, ncomp,
+                                                     Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
+                                                                             (std::uint64_t)blockDim.x)));
+            }
+        });
+        ndone += ec.ntotalthreads;
+    }
     AMREX_GPU_ERROR_CHECK();
 }
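The public signatures are unchanged by this refactor, so existing call sites pick up the multi-launch strategy transparently. A typical caller might look like this (the function name, includes, and array are illustrative, not part of this commit):

#include <AMReX_Gpu.H>
#include <AMReX_Box.H>
#include <AMReX_Array4.H>

// Illustrative call site: scales every cell of `a` inside `bx`.
// After this commit, ParallelFor runs each cell in exactly one device
// thread; there is no grid-stride loop inside the generated kernel.
void scale (amrex::Box const& bx, amrex::Array4<amrex::Real> const& a,
            amrex::Real s)
{
    amrex::ParallelFor(bx,
    [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
    {
        a(i,j,k) *= s;
    });
}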
