Skip to content

Commit b942834

Browse files
authored
Revise the mscclpp datatype (#671)
Replace `int dtype` and `ncclDataType_t dtype` with `mscclpp::DataType` throughout the API interface, and add a data-type conversion from `ncclDataType_t` to `mscclpp::DataType`.
1 parent a19bca9 commit b942834

File tree

14 files changed

+250
-173
lines changed

14 files changed

+250
-173
lines changed

apps/nccl/src/allgather.cu

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <mscclpp/gpu_utils.hpp>
99

1010
#include "allgather.hpp"
11+
#include "datatype_conversion.hpp"
1112
#include "debug.h"
1213

1314
AllgatherAlgo6::AllgatherAlgo6() : disableChannelCache_(false) {
@@ -23,10 +24,11 @@ void AllgatherAlgo6::initialize(std::shared_ptr<mscclpp::Communicator> comm,
2324
}
2425

2526
ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
26-
void* output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
27+
void* output, size_t count, mscclpp::DataType dtype,
28+
cudaStream_t stream,
2729
std::unordered_map<std::string, std::shared_ptr<void>>&) {
2830
int nBlocks = 28;
29-
const size_t bytes = count * ncclTypeSize(dtype);
31+
const size_t bytes = count * getDataTypeSize(dtype);
3032
const size_t nElem = bytes / sizeof(int);
3133
int rank = ctx->rank;
3234
if (bytes <= 32 * (1 << 20)) {
@@ -61,7 +63,7 @@ ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::
6163

6264
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
6365
const void*, void* output, size_t count,
64-
ncclDataType_t dtype) {
66+
mscclpp::DataType dtype) {
6567
auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
6668
ctx->rank = comm->bootstrap()->getRank();
6769
ctx->workSize = comm->bootstrap()->getNranks();
@@ -70,7 +72,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std:
7072
// setup semaphores
7173
ctx->memorySemaphores = this->memorySemaphores_;
7274

73-
size_t bytes = count * ncclTypeSize(dtype);
75+
size_t bytes = count * getDataTypeSize(dtype);
7476
size_t recvBytes;
7577
CUdeviceptr recvBasePtr;
7678
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
@@ -98,7 +100,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std:
98100
}
99101

100102
mscclpp::AlgorithmCtxKey AllgatherAlgo6::generateAllgatherContextKey(const void*, void* output, size_t,
101-
ncclDataType_t) {
103+
mscclpp::DataType) {
102104
static int tag = 0;
103105
if (disableChannelCache_) {
104106
// always return a new key if channel cache is disabled
@@ -116,15 +118,15 @@ mscclpp::Algorithm AllgatherAlgo6::build() {
116118
"default_allgather6", "allgather",
117119
[self](std::shared_ptr<mscclpp::Communicator> comm,
118120
std::unordered_map<std::string, std::shared_ptr<void>>& extras) { self->initialize(comm, extras); },
119-
[self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t count, int dtype,
120-
cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
121-
return self->allgatherKernelFunc(ctx, input, output, count, static_cast<ncclDataType_t>(dtype), stream, extras);
121+
[self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t count,
122+
mscclpp::DataType dtype, cudaStream_t stream,
123+
std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
124+
return self->allgatherKernelFunc(ctx, input, output, count, dtype, stream, extras);
122125
},
123-
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count, int dtype) {
124-
return self->initAllgatherContext(comm, input, output, count, static_cast<ncclDataType_t>(dtype));
125-
},
126-
[self](const void* input, void* output, size_t count, int dtype) {
127-
return self->generateAllgatherContextKey(input, output, count, static_cast<ncclDataType_t>(dtype));
126+
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count,
127+
mscclpp::DataType dtype) { return self->initAllgatherContext(comm, input, output, count, dtype); },
128+
[self](const void* input, void* output, size_t count, mscclpp::DataType dtype) {
129+
return self->generateAllgatherContextKey(input, output, count, dtype);
128130
});
129131
return allgatherAlgo;
130132
}
@@ -137,10 +139,11 @@ void AllgatherAlgo8::initialize(std::shared_ptr<mscclpp::Communicator> comm,
137139
}
138140

139141
ncclResult_t AllgatherAlgo8::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
140-
void* output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
142+
void* output, size_t count, mscclpp::DataType dtype,
143+
cudaStream_t stream,
141144
std::unordered_map<std::string, std::shared_ptr<void>>&) {
142145
int rank = ctx->rank;
143-
const size_t bytes = count * ncclTypeSize(dtype);
146+
const size_t bytes = count * getDataTypeSize(dtype);
144147
const size_t nElem = bytes / sizeof(int);
145148
if ((char*)input == (char*)output + rank * bytes) {
146149
allgather8<false><<<56, 1024, 0, stream>>>((void*)input, this->scratchBuffer_.get(), (void*)output,
@@ -161,7 +164,7 @@ ncclResult_t AllgatherAlgo8::allgatherKernelFunc(const std::shared_ptr<mscclpp::
161164

162165
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
163166
const void* input, void*, size_t count,
164-
ncclDataType_t dtype) {
167+
mscclpp::DataType dtype) {
165168
constexpr int nChannelsPerConnection = 56;
166169

167170
auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
@@ -172,7 +175,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std:
172175
// setup semaphores
173176
ctx->memorySemaphores = std::move(setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection));
174177

175-
size_t bytes = count * ncclTypeSize(dtype);
178+
size_t bytes = count * getDataTypeSize(dtype);
176179
// register the memory for the broadcast operation
177180
mscclpp::RegisteredMemory localMemory = comm->registerMemory((void*)input, bytes, mscclpp::Transport::CudaIpc);
178181
mscclpp::RegisteredMemory scratchMemory =
@@ -192,7 +195,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std:
192195
return ctx;
193196
}
194197

195-
mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t) {
198+
mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey(const void*, void*, size_t, mscclpp::DataType) {
196199
// always return same key, non-zero copy algo
197200
return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
198201
}
@@ -203,15 +206,15 @@ mscclpp::Algorithm AllgatherAlgo8::build() {
203206
"default_allgather8", "allgather",
204207
[self](std::shared_ptr<mscclpp::Communicator> comm,
205208
std::unordered_map<std::string, std::shared_ptr<void>>& extras) { self->initialize(comm, extras); },
206-
[self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t count, int dtype,
207-
cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
208-
return self->allgatherKernelFunc(ctx, input, output, count, static_cast<ncclDataType_t>(dtype), stream, extras);
209-
},
210-
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count, int dtype) {
211-
return self->initAllgatherContext(comm, input, output, count, static_cast<ncclDataType_t>(dtype));
209+
[self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t count,
210+
mscclpp::DataType dtype, cudaStream_t stream,
211+
std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
212+
return self->allgatherKernelFunc(ctx, input, output, count, dtype, stream, extras);
212213
},
213-
[self](const void* input, void* output, size_t count, int dtype) {
214-
return self->generateAllgatherContextKey(input, output, count, static_cast<ncclDataType_t>(dtype));
214+
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count,
215+
mscclpp::DataType dtype) { return self->initAllgatherContext(comm, input, output, count, dtype); },
216+
[self](const void* input, void* output, size_t count, mscclpp::DataType dtype) {
217+
return self->generateAllgatherContextKey(input, output, count, dtype);
215218
});
216219
return allgatherAlgo;
217220
}

apps/nccl/src/allgather.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <mscclpp/algorithm.hpp>
1010
#include <mscclpp/concurrency_device.hpp>
1111
#include <mscclpp/core.hpp>
12+
#include <mscclpp/executor.hpp>
1213
#include <mscclpp/gpu.hpp>
1314
#include <mscclpp/memory_channel.hpp>
1415
#include <mscclpp/memory_channel_device.hpp>
@@ -222,12 +223,12 @@ class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder {
222223

223224
void initialize(std::shared_ptr<mscclpp::Communicator> comm, std::unordered_map<std::string, std::shared_ptr<void>>&);
224225
ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
225-
size_t count, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
226+
size_t count, mscclpp::DataType dtype, cudaStream_t stream,
226227
std::unordered_map<std::string, std::shared_ptr<void>>& extras);
227228

228229
std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
229-
void* output, size_t, ncclDataType_t);
230-
mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t);
230+
void* output, size_t, mscclpp::DataType);
231+
mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, mscclpp::DataType);
231232
};
232233

233234
class AllgatherAlgo8 : public mscclpp::AlgorithmBuilder {
@@ -240,12 +241,12 @@ class AllgatherAlgo8 : public mscclpp::AlgorithmBuilder {
240241
void initialize(std::shared_ptr<mscclpp::Communicator> comm,
241242
std::unordered_map<std::string, std::shared_ptr<void>>& extras);
242243
ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
243-
size_t count, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
244+
size_t count, mscclpp::DataType dtype, cudaStream_t stream,
244245
std::unordered_map<std::string, std::shared_ptr<void>>& extras);
245246

246247
std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
247-
void* output, size_t, ncclDataType_t);
248-
mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t);
248+
void* output, size_t, mscclpp::DataType);
249+
mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, mscclpp::DataType);
249250

250251
size_t scratchBufferSize_;
251252
std::shared_ptr<char> scratchBuffer_;

0 commit comments

Comments
 (0)