Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions Docs/sphinx_documentation/source/Basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ Initialize and Finalize
As we have mentioned, :cpp:`Initialize` must be called to initialize
the execution environment for AMReX and :cpp:`Finalize` must be paired
with :cpp:`Initialize` to release the resources used by AMReX. There
are two versions of :cpp:`Initialize`.
are three versions of :cpp:`Initialize`.

.. highlight:: c++

Expand All @@ -727,20 +727,23 @@ are two versions of :cpp:`Initialize`.
void Initialize (MPI_Comm mpi_comm,
std::ostream& a_osout = std::cout,
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);
ErrorHandler a_errhandler = nullptr,
int a_device_id = -1);

AMReX* Initialize (int& argc, char**& argv,
const std::function<void()>& func_parm_parse,
std::ostream& a_osout = std::cout,
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);
ErrorHandler a_errhandler = nullptr,
int a_device_id = -1);

void Initialize (int& argc, char**& argv, bool build_parm_parse=true,
MPI_Comm mpi_comm = MPI_COMM_WORLD,
const std::function<void()>& func_parm_parse = {},
std::ostream& a_osout = std::cout,
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);
ErrorHandler a_errhandler = nullptr,
int a_device_id = -1);

:cpp:`Initialize` checks if MPI has been initialized. If it has, AMReX will
duplicate the ``MPI_Comm`` argument provided by the users in the first and
Expand Down Expand Up @@ -768,6 +771,16 @@ second and third versions, the user may also pass a function that adds
parameters to the ParmParse database instead of reading from command line or
input file.

The last optional parameter, :cpp:`int a_device_id = -1`, applies to
GPU builds only. By default, when multiple GPU devices are visible, AMReX
automatically selects one for you. In most cases, you should rely on this
default behavior and omit the optional argument. However, if another library
has already been initialized and assigned processes to specific devices, you
may need AMReX to use a particular GPU. In that case, you can pass the
desired device ID to :cpp:`amrex::Initialize`. Conversely, if you want
another library to use the device selected by AMReX, you can obtain the
device ID by calling :cpp:`int amrex::Gpu::Device::deviceId()`.

Because many AMReX classes and functions (including destructors
inserted by the compiler) do not function properly after
:cpp:`amrex::Finalize` is called, it's best to put the codes between
Expand Down
9 changes: 6 additions & 3 deletions Src/Base/AMReX.H
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,25 @@ namespace amrex
AMReX* Initialize (MPI_Comm mpi_comm,
std::ostream& a_osout = std::cout,
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);
ErrorHandler a_errhandler = nullptr,
int a_device_id = -1);

// The returned AMReX* is non-owning! To delete it, call Finalize(AMReX*).
AMReX* Initialize (int& argc, char**& argv,
const std::function<void()>& func_parm_parse,
std::ostream& a_osout = std::cout,
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);
ErrorHandler a_errhandler = nullptr,
int a_device_id = -1);

// The returned AMReX* is non-owning! To delete it, call Finalize(AMReX*).
AMReX* Initialize (int& argc, char**& argv, bool build_parm_parse=true,
MPI_Comm mpi_comm = MPI_COMM_WORLD,
const std::function<void()>& func_parm_parse = {},
std::ostream& a_osout = std::cout,
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);
ErrorHandler a_errhandler = nullptr,
int a_device_id = -1);

// \brief Minimal version of initialization.
//
Expand Down
14 changes: 8 additions & 6 deletions Src/Base/AMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,28 +331,29 @@ amrex::ExecOnInitialize (std::function<void()> f)
amrex::AMReX*
amrex::Initialize (MPI_Comm mpi_comm,
std::ostream& a_osout, std::ostream& a_oserr,
ErrorHandler a_errhandler)
ErrorHandler a_errhandler, int a_device_id)
{
int argc = 0;
char** argv = nullptr;
return Initialize(argc, argv, false, mpi_comm, {}, a_osout, a_oserr, a_errhandler);
return Initialize(argc, argv, false, mpi_comm, {}, a_osout, a_oserr,
a_errhandler, a_device_id);
}

amrex::AMReX*
amrex::Initialize (int& argc, char**& argv,
const std::function<void()>& func_parm_parse,
std::ostream& a_osout, std::ostream& a_oserr,
ErrorHandler a_errhandler)
ErrorHandler a_errhandler, int a_device_id)
{
return Initialize(argc, argv, true, MPI_COMM_WORLD, func_parm_parse,
a_osout, a_oserr, a_errhandler);
a_osout, a_oserr, a_errhandler, a_device_id);
}

amrex::AMReX*
amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
MPI_Comm mpi_comm, const std::function<void()>& func_parm_parse,
std::ostream& a_osout, std::ostream& a_oserr,
ErrorHandler a_errhandler)
ErrorHandler a_errhandler, int a_device_id)
{
system::exename.clear();
if (initialization_by_init_minimal) {
Expand Down Expand Up @@ -541,9 +542,10 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,

Machine::Initialize();

amrex::ignore_unused(a_device_id);
#ifdef AMREX_USE_GPU
// Initialize after ParmParse so that we can read inputs.
Gpu::Device::Initialize(initialization_by_init_minimal);
Gpu::Device::Initialize(initialization_by_init_minimal, a_device_id);
#ifdef AMREX_USE_CUPTI
CuptiInitialize();
#endif
Expand Down
2 changes: 1 addition & 1 deletion Src/Base/AMReX_GpuDevice.H
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Device

public:

static void Initialize (bool minimal);
static void Initialize (bool minimal, int a_device_id);
static void Finalize ();

#if defined(AMREX_USE_GPU)
Expand Down
8 changes: 5 additions & 3 deletions Src/Base/AMReX_GpuDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ namespace {
#endif

void
Device::Initialize (bool minimal)
Device::Initialize (bool minimal, int a_device_id)
{
amrex::ignore_unused(minimal);
amrex::ignore_unused(minimal, a_device_id);
#ifdef AMREX_USE_GPU

#if defined(AMREX_USE_CUDA) && (defined(AMREX_PROFILING) || defined(AMREX_TINY_PROFILING))
Expand Down Expand Up @@ -203,6 +203,8 @@ Device::Initialize (bool minimal)
device_id = 0;
AMREX_HIP_OR_CUDA(AMREX_HIP_SAFE_CALL (hipGetDevice(&device_id));,
AMREX_CUDA_SAFE_CALL(cudaGetDevice(&device_id)); );
} else if (a_device_id >= 0) {
device_id = a_device_id;
} else if (ParallelDescriptor::NProcs() == 1) {
device_id = 0;
}
Expand All @@ -219,7 +221,7 @@ Device::Initialize (bool minimal)
}
}

if (gpu_device_count > 1 && ! minimal) {
if (gpu_device_count > 1 && ! minimal && a_device_id < 0) {
if (Machine::name() == "nersc.perlmutter") {
// The CPU/GPU mapping on perlmutter has the reverse order.
device_id = gpu_device_count - device_id - 1;
Expand Down
Loading