Skip to content

Commit 159ed0b

Browse files
WeiqunZhang and atmyers
authored
amrex::Initialize: Add optional argument of device ID (#4741)
In most cases, you should let AMReX select a device when multiple GPU devices are visible. However, if another library has already been initialized and assigned processes to specific devices, you may need AMReX to use a particular GPU. In that case, you can pass the desired device ID to `amrex::Initialize`. --------- Co-authored-by: Andrew Myers <[email protected]>
1 parent bab255c commit 159ed0b

File tree

5 files changed

+37
-17
lines changed

5 files changed

+37
-17
lines changed

Docs/sphinx_documentation/source/Basics.rst

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ Initialize and Finalize
718718
As we have mentioned, :cpp:`Initialize` must be called to initialize
719719
the execution environment for AMReX and :cpp:`Finalize` must be paired
720720
with :cpp:`Initialize` to release the resources used by AMReX. There
721-
are two versions of :cpp:`Initialize`.
721+
are three versions of :cpp:`Initialize`.
722722

723723
.. highlight:: c++
724724

@@ -727,20 +727,23 @@ are two versions of :cpp:`Initialize`.
727727
AMReX* Initialize (MPI_Comm mpi_comm,
728728
std::ostream& a_osout = std::cout,
729729
std::ostream& a_oserr = std::cerr,
730-
ErrorHandler a_errhandler = nullptr);
730+
ErrorHandler a_errhandler = nullptr,
731+
int a_device_id = -1);
731732

732733
AMReX* Initialize (int& argc, char**& argv,
733734
const std::function<void()>& func_parm_parse,
734735
std::ostream& a_osout = std::cout,
735736
std::ostream& a_oserr = std::cerr,
736-
ErrorHandler a_errhandler = nullptr);
737+
ErrorHandler a_errhandler = nullptr,
738+
int a_device_id = -1);
737739

738740
AMReX* Initialize (int& argc, char**& argv, bool build_parm_parse=true,
739741
MPI_Comm mpi_comm = MPI_COMM_WORLD,
740742
const std::function<void()>& func_parm_parse = {},
741743
std::ostream& a_osout = std::cout,
742744
std::ostream& a_oserr = std::cerr,
743-
ErrorHandler a_errhandler = nullptr);
745+
ErrorHandler a_errhandler = nullptr,
746+
int a_device_id = -1);
744747

745748
:cpp:`Initialize` checks if MPI has been initialized. If it has, AMReX will
746749
duplicate the ``MPI_Comm`` argument provided by the users in the first and
@@ -768,6 +771,16 @@ second and third versions, the user may also pass a function that adds
768771
parameters to the ParmParse database instead of reading from command line or
769772
input file.
770773

774+
The last optional parameter, :cpp:`int a_device_id = -1`, applies to
775+
GPU builds only. By default, when multiple GPU devices are visible, AMReX
776+
automatically selects one for you. In most cases, you should rely on this
777+
default behavior and omit the optional argument. However, if another library
778+
has already been initialized and assigned processes to specific devices, you
779+
may need AMReX to use a particular GPU. In that case, you can pass the
780+
desired device ID to :cpp:`amrex::Initialize`. Conversely, if you want
781+
another library to use the device selected by AMReX, you can obtain the
782+
device ID by calling :cpp:`int amrex::Gpu::Device::deviceId()`.
783+
771784
Because many AMReX classes and functions (including destructors
772785
inserted by the compiler) do not function properly after
773786
:cpp:`amrex::Finalize` is called, it's best to put the code between

Src/Base/AMReX.H

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,25 @@ namespace amrex
7878
AMReX* Initialize (MPI_Comm mpi_comm,
7979
std::ostream& a_osout = std::cout,
8080
std::ostream& a_oserr = std::cerr,
81-
ErrorHandler a_errhandler = nullptr);
81+
ErrorHandler a_errhandler = nullptr,
82+
int a_device_id = -1);
8283

8384
// The returned AMReX* is non-owning! To delete it, call Finalize(AMReX*).
8485
AMReX* Initialize (int& argc, char**& argv,
8586
const std::function<void()>& func_parm_parse,
8687
std::ostream& a_osout = std::cout,
8788
std::ostream& a_oserr = std::cerr,
88-
ErrorHandler a_errhandler = nullptr);
89+
ErrorHandler a_errhandler = nullptr,
90+
int a_device_id = -1);
8991

9092
// The returned AMReX* is non-owning! To delete it, call Finalize(AMReX*).
9193
AMReX* Initialize (int& argc, char**& argv, bool build_parm_parse=true,
9294
MPI_Comm mpi_comm = MPI_COMM_WORLD,
9395
const std::function<void()>& func_parm_parse = {},
9496
std::ostream& a_osout = std::cout,
9597
std::ostream& a_oserr = std::cerr,
96-
ErrorHandler a_errhandler = nullptr);
98+
ErrorHandler a_errhandler = nullptr,
99+
int a_device_id = -1);
97100

98101
// \brief Minimal version of initialization.
99102
//

Src/Base/AMReX.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,28 +331,29 @@ amrex::ExecOnInitialize (std::function<void()> f)
331331
amrex::AMReX*
332332
amrex::Initialize (MPI_Comm mpi_comm,
333333
std::ostream& a_osout, std::ostream& a_oserr,
334-
ErrorHandler a_errhandler)
334+
ErrorHandler a_errhandler, int a_device_id)
335335
{
336336
int argc = 0;
337337
char** argv = nullptr;
338-
return Initialize(argc, argv, false, mpi_comm, {}, a_osout, a_oserr, a_errhandler);
338+
return Initialize(argc, argv, false, mpi_comm, {}, a_osout, a_oserr,
339+
a_errhandler, a_device_id);
339340
}
340341

341342
amrex::AMReX*
342343
amrex::Initialize (int& argc, char**& argv,
343344
const std::function<void()>& func_parm_parse,
344345
std::ostream& a_osout, std::ostream& a_oserr,
345-
ErrorHandler a_errhandler)
346+
ErrorHandler a_errhandler, int a_device_id)
346347
{
347348
return Initialize(argc, argv, true, MPI_COMM_WORLD, func_parm_parse,
348-
a_osout, a_oserr, a_errhandler);
349+
a_osout, a_oserr, a_errhandler, a_device_id);
349350
}
350351

351352
amrex::AMReX*
352353
amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
353354
MPI_Comm mpi_comm, const std::function<void()>& func_parm_parse,
354355
std::ostream& a_osout, std::ostream& a_oserr,
355-
ErrorHandler a_errhandler)
356+
ErrorHandler a_errhandler, int a_device_id)
356357
{
357358
system::exename.clear();
358359
if (initialization_by_init_minimal) {
@@ -541,9 +542,10 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
541542

542543
Machine::Initialize();
543544

545+
amrex::ignore_unused(a_device_id);
544546
#ifdef AMREX_USE_GPU
545547
// Initialize after ParmParse so that we can read inputs.
546-
Gpu::Device::Initialize(initialization_by_init_minimal);
548+
Gpu::Device::Initialize(initialization_by_init_minimal, a_device_id);
547549
#ifdef AMREX_USE_CUPTI
548550
CuptiInitialize();
549551
#endif

Src/Base/AMReX_GpuDevice.H

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class Device
5353

5454
public:
5555

56-
static void Initialize (bool minimal);
56+
static void Initialize (bool minimal, int a_device_id);
5757
static void Finalize ();
5858

5959
#if defined(AMREX_USE_GPU)

Src/Base/AMReX_GpuDevice.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ namespace {
142142
#endif
143143

144144
void
145-
Device::Initialize (bool minimal)
145+
Device::Initialize (bool minimal, int a_device_id)
146146
{
147-
amrex::ignore_unused(minimal);
147+
amrex::ignore_unused(minimal, a_device_id);
148148
#ifdef AMREX_USE_GPU
149149

150150
#if defined(AMREX_USE_CUDA) && (defined(AMREX_PROFILING) || defined(AMREX_TINY_PROFILING))
@@ -203,6 +203,8 @@ Device::Initialize (bool minimal)
203203
device_id = 0;
204204
AMREX_HIP_OR_CUDA(AMREX_HIP_SAFE_CALL (hipGetDevice(&device_id));,
205205
AMREX_CUDA_SAFE_CALL(cudaGetDevice(&device_id)); );
206+
} else if (a_device_id >= 0) {
207+
device_id = a_device_id;
206208
} else if (ParallelDescriptor::NProcs() == 1) {
207209
device_id = 0;
208210
}
@@ -219,7 +221,7 @@ Device::Initialize (bool minimal)
219221
}
220222
}
221223

222-
if (gpu_device_count > 1 && ! minimal) {
224+
if (gpu_device_count > 1 && ! minimal && a_device_id < 0) {
223225
if (Machine::name() == "nersc.perlmutter") {
224226
// The CPU/GPU mapping on perlmutter has the reverse order.
225227
device_id = gpu_device_count - device_id - 1;

0 commit comments

Comments
 (0)