@@ -747,35 +747,76 @@ void launch (int nblocks, int nthreads_per_block, gpuStream_t stream, L&& f) noexcept
     launch(nblocks, nthreads_per_block, 0, stream, std::forward<L>(f));
 }
 
-template <int MT, typename T, typename L>
+template <int MT, typename T, typename L, std::enable_if_t<std::is_integral_v<T>,int> FOO = 0>
 void launch (T const& n, L const& f) noexcept
 {
+    static_assert(sizeof(T) >= 2);
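+    // Assumption: the assert rules out 8-bit index types, for which T(MT)
+    // in the tid computation below could not represent MT.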
     if (amrex::isEmpty(n)) { return; }
-    const auto ec = Gpu::makeExecutionConfig<MT>(n);
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (auto const i : Gpu::Range(n)) {
-            f(i);
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(n);
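+    // The iteration space is split into chunks, with one kernel launch per
+    // execution config, instead of the previous single grid-stride launch.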
+    T ndone = 0;
+    for (auto const& ec : nec) {
+        T nleft = n - ndone;
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            // This will not overflow, even though nblocks*MT might.
+            auto tid = T(MT)*T(blockIdx.x)+T(threadIdx.x);
+            if (tid < nleft) {
+                f(tid+ndone);
+            }
+        });
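+        // Advance ndone only when more chunks follow; skipping the update on
+        // the final chunk keeps the addition from overflowing T.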
+        if (nleft > ec.ntotalthreads) {
+            ndone += T(ec.ntotalthreads);
         }
-    });
+    }
+    AMREX_GPU_ERROR_CHECK();
+}
+
+template <int MT, int dim, typename L>
+void launch (BoxND<dim> const& box, L const& f) noexcept
+{
+    if (box.isEmpty()) { return; }
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(box);
+    const BoxIndexerND<dim> indexer(box);
+    const auto type = box.ixType();
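+    // Each thread is handed a one-cell BoxND covering its own cell; the
+    // indexer maps a linear cell index back to an IntVect within box.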
+    std::uint64_t ndone = 0;
+    for (auto const& ec : nec) {
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            auto icell = std::uint64_t(MT)*blockIdx.x+threadIdx.x + ndone;
+            if (icell < indexer.numPts()) {
+                auto iv = indexer.intVect(icell);
+                f(BoxND<dim>(iv,iv,type));
+            }
+        });
+        ndone += ec.ntotalthreads;
+    }
     AMREX_GPU_ERROR_CHECK();
 }
 
 template <int MT, typename T, typename L, typename M=std::enable_if_t<std::is_integral<T>::value> >
 std::enable_if_t<MaybeDeviceRunnable<L>::value>
 ParallelFor (Gpu::KernelInfo const&, T n, L const& f) noexcept
 {
+    static_assert(sizeof(T) >= 2);
     if (amrex::isEmpty(n)) { return; }
-    const auto ec = Gpu::makeExecutionConfig<MT>(n);
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (Long i = Long(blockDim.x)*blockIdx.x+threadIdx.x, stride = Long(blockDim.x)*gridDim.x;
-             i < Long(n); i += stride) {
-            detail::call_f_scalar_handler(f, T(i),
-                Gpu::Handler(amrex::min((std::uint64_t(n)-i+(std::uint64_t)threadIdx.x),
-                                        (std::uint64_t)blockDim.x)));
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(n);
+    T ndone = 0;
+    for (auto const& ec : nec) {
+        T nleft = n - ndone;
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            // This will not overflow, even though nblocks*MT might.
+            auto tid = T(MT)*T(blockIdx.x)+T(threadIdx.x);
+            if (tid < nleft) {
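+                // The Handler receives the number of threads in this block
+                // that still have work: blockDim.x for full blocks, fewer in
+                // the last, partially filled block.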
+                detail::call_f_scalar_handler(f, tid+ndone,
+                    Gpu::Handler(amrex::min((std::uint64_t(nleft)-tid+(std::uint64_t)threadIdx.x),
+                                            (std::uint64_t)blockDim.x)));
+            }
+        });
+        if (nleft > ec.ntotalthreads) {
+            ndone += ec.ntotalthreads;
         }
-    });
+    }
     AMREX_GPU_ERROR_CHECK();
 }
 
@@ -785,18 +826,21 @@ ParallelFor (Gpu::KernelInfo const&, BoxND<dim> const& box, L const& f) noexcept
 {
     if (amrex::isEmpty(box)) { return; }
     const BoxIndexerND<dim> indexer(box);
-    const auto ec = Gpu::makeExecutionConfig<MT>(box.numPts());
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (std::uint64_t icell = std::uint64_t(blockDim.x)*blockIdx.x+threadIdx.x, stride = std::uint64_t(blockDim.x)*gridDim.x;
-             icell < indexer.numPts(); icell += stride)
-        {
-            auto iv = indexer.intVect(icell);
-            detail::call_f_intvect_handler(f, iv,
-                Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
-                                        (std::uint64_t)blockDim.x)));
-        }
-    });
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(box);
+    std::uint64_t ndone = 0;
+    for (auto const& ec : nec) {
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            auto icell = std::uint64_t(MT)*blockIdx.x+threadIdx.x + ndone;
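+            // ndone offsets this launch's threads to their global position in
+            // the flattened box, so icell indexes the whole iteration space.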
+            if (icell < indexer.numPts()) {
+                auto iv = indexer.intVect(icell);
+                detail::call_f_intvect_handler(f, iv,
+                    Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
+                                            (std::uint64_t)blockDim.x)));
+            }
+        });
+        ndone += ec.ntotalthreads;
+    }
     AMREX_GPU_ERROR_CHECK();
 }
 
@@ -806,17 +850,21 @@ ParallelFor (Gpu::KernelInfo const&, BoxND<dim> const& box, T ncomp, L const& f)
 {
     if (amrex::isEmpty(box)) { return; }
     const BoxIndexerND<dim> indexer(box);
-    const auto ec = Gpu::makeExecutionConfig<MT>(box.numPts());
-    AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(),
-    [=] AMREX_GPU_DEVICE () noexcept {
-        for (std::uint64_t icell = std::uint64_t(blockDim.x)*blockIdx.x+threadIdx.x, stride = std::uint64_t(blockDim.x)*gridDim.x;
-             icell < indexer.numPts(); icell += stride) {
-            auto iv = indexer.intVect(icell);
-            detail::call_f_intvect_ncomp_handler(f, iv, ncomp,
-                Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
-                                        (std::uint64_t)blockDim.x)));
-        }
-    });
+    const auto& nec = Gpu::makeNExecutionConfigs<MT>(box);
+    std::uint64_t ndone = 0;
+    for (auto const& ec : nec) {
+        AMREX_LAUNCH_KERNEL(MT, ec.nblocks, MT, 0, Gpu::gpuStream(),
+        [=] AMREX_GPU_DEVICE () noexcept {
+            auto icell = std::uint64_t(MT)*blockIdx.x+threadIdx.x + ndone;
+            if (icell < indexer.numPts()) {
+                auto iv = indexer.intVect(icell);
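+                // Presumably the ncomp handler invokes f for every component
+                // of this cell, so one thread covers all ncomp components.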
+                detail::call_f_intvect_ncomp_handler(f, iv, ncomp,
+                    Gpu::Handler(amrex::min((indexer.numPts()-icell+(std::uint64_t)threadIdx.x),
+                                            (std::uint64_t)blockDim.x)));
+            }
+        });
+        ndone += ec.ntotalthreads;
+    }
     AMREX_GPU_ERROR_CHECK();
 }
 