88#include < mscclpp/gpu_utils.hpp>
99
1010#include " allgather.hpp"
11+ #include " datatype_conversion.hpp"
1112#include " debug.h"
1213
1314AllgatherAlgo6::AllgatherAlgo6 () : disableChannelCache_(false ) {
@@ -23,10 +24,11 @@ void AllgatherAlgo6::initialize(std::shared_ptr<mscclpp::Communicator> comm,
2324}
2425
2526ncclResult_t AllgatherAlgo6::allgatherKernelFunc (const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input,
26- void * output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
27+ void * output, size_t count, mscclpp::DataType dtype,
28+ cudaStream_t stream,
2729 std::unordered_map<std::string, std::shared_ptr<void >>&) {
2830 int nBlocks = 28 ;
29- const size_t bytes = count * ncclTypeSize (dtype);
31+ const size_t bytes = count * getDataTypeSize (dtype);
3032 const size_t nElem = bytes / sizeof (int );
3133 int rank = ctx->rank ;
3234 if (bytes <= 32 * (1 << 20 )) {
@@ -61,7 +63,7 @@ ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::
6163
6264std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext (std::shared_ptr<mscclpp::Communicator> comm,
6365 const void *, void * output, size_t count,
64- ncclDataType_t dtype) {
66+ mscclpp::DataType dtype) {
6567 auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
6668 ctx->rank = comm->bootstrap ()->getRank ();
6769 ctx->workSize = comm->bootstrap ()->getNranks ();
@@ -70,7 +72,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std:
7072 // setup semaphores
7173 ctx->memorySemaphores = this ->memorySemaphores_ ;
7274
73- size_t bytes = count * ncclTypeSize (dtype);
75+ size_t bytes = count * getDataTypeSize (dtype);
7476 size_t recvBytes;
7577 CUdeviceptr recvBasePtr;
7678 MSCCLPP_CUTHROW (cuMemGetAddressRange (&recvBasePtr, &recvBytes, (CUdeviceptr)output));
@@ -98,7 +100,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std:
98100}
99101
100102mscclpp::AlgorithmCtxKey AllgatherAlgo6::generateAllgatherContextKey (const void *, void * output, size_t ,
101- ncclDataType_t ) {
103+ mscclpp::DataType ) {
102104 static int tag = 0 ;
103105 if (disableChannelCache_) {
104106 // always return a new key if channel cache is disabled
@@ -116,15 +118,15 @@ mscclpp::Algorithm AllgatherAlgo6::build() {
116118 " default_allgather6" , " allgather" ,
117119 [self](std::shared_ptr<mscclpp::Communicator> comm,
118120 std::unordered_map<std::string, std::shared_ptr<void >>& extras) { self->initialize (comm, extras); },
119- [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t count, int dtype,
120- cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
121- return self->allgatherKernelFunc (ctx, input, output, count, static_cast <ncclDataType_t>(dtype), stream, extras);
121+ [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t count,
122+ mscclpp::DataType dtype, cudaStream_t stream,
123+ std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
124+ return self->allgatherKernelFunc (ctx, input, output, count, dtype, stream, extras);
122125 },
123- [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t count, int dtype) {
124- return self->initAllgatherContext (comm, input, output, count, static_cast <ncclDataType_t>(dtype));
125- },
126- [self](const void * input, void * output, size_t count, int dtype) {
127- return self->generateAllgatherContextKey (input, output, count, static_cast <ncclDataType_t>(dtype));
126+ [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t count,
127+ mscclpp::DataType dtype) { return self->initAllgatherContext (comm, input, output, count, dtype); },
128+ [self](const void * input, void * output, size_t count, mscclpp::DataType dtype) {
129+ return self->generateAllgatherContextKey (input, output, count, dtype);
128130 });
129131 return allgatherAlgo;
130132}
@@ -137,10 +139,11 @@ void AllgatherAlgo8::initialize(std::shared_ptr<mscclpp::Communicator> comm,
137139}
138140
139141ncclResult_t AllgatherAlgo8::allgatherKernelFunc (const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input,
140- void * output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
142+ void * output, size_t count, mscclpp::DataType dtype,
143+ cudaStream_t stream,
141144 std::unordered_map<std::string, std::shared_ptr<void >>&) {
142145 int rank = ctx->rank ;
143- const size_t bytes = count * ncclTypeSize (dtype);
146+ const size_t bytes = count * getDataTypeSize (dtype);
144147 const size_t nElem = bytes / sizeof (int );
145148 if ((char *)input == (char *)output + rank * bytes) {
146149 allgather8<false ><<<56 , 1024 , 0 , stream>>> ((void *)input, this ->scratchBuffer_ .get (), (void *)output,
@@ -161,7 +164,7 @@ ncclResult_t AllgatherAlgo8::allgatherKernelFunc(const std::shared_ptr<mscclpp::
161164
162165std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext (std::shared_ptr<mscclpp::Communicator> comm,
163166 const void * input, void *, size_t count,
164- ncclDataType_t dtype) {
167+ mscclpp::DataType dtype) {
165168 constexpr int nChannelsPerConnection = 56 ;
166169
167170 auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
@@ -172,7 +175,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std:
172175 // setup semaphores
173176 ctx->memorySemaphores = std::move (setupMemorySemaphores (comm, this ->conns_ , nChannelsPerConnection));
174177
175- size_t bytes = count * ncclTypeSize (dtype);
178+ size_t bytes = count * getDataTypeSize (dtype);
176179 // register the memory for the broadcast operation
177180 mscclpp::RegisteredMemory localMemory = comm->registerMemory ((void *)input, bytes, mscclpp::Transport::CudaIpc);
178181 mscclpp::RegisteredMemory scratchMemory =
@@ -192,7 +195,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std:
192195 return ctx;
193196}
194197
195- mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey (const void *, void *, size_t , ncclDataType_t ) {
198+ mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey (const void *, void *, size_t , mscclpp::DataType ) {
196199 // always return same key, non-zero copy algo
197200 return mscclpp::AlgorithmCtxKey{nullptr , nullptr , 0 , 0 , 0 };
198201}
@@ -203,15 +206,15 @@ mscclpp::Algorithm AllgatherAlgo8::build() {
203206 " default_allgather8" , " allgather" ,
204207 [self](std::shared_ptr<mscclpp::Communicator> comm,
205208 std::unordered_map<std::string, std::shared_ptr<void >>& extras) { self->initialize (comm, extras); },
206- [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t count, int dtype,
207- cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
208- return self->allgatherKernelFunc (ctx, input, output, count, static_cast <ncclDataType_t>(dtype), stream, extras);
209- },
210- [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t count, int dtype) {
211- return self->initAllgatherContext (comm, input, output, count, static_cast <ncclDataType_t>(dtype));
209+ [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t count,
210+ mscclpp::DataType dtype, cudaStream_t stream,
211+ std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
212+ return self->allgatherKernelFunc (ctx, input, output, count, dtype, stream, extras);
212213 },
213- [self](const void * input, void * output, size_t count, int dtype) {
214- return self->generateAllgatherContextKey (input, output, count, static_cast <ncclDataType_t>(dtype));
214+ [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t count,
215+ mscclpp::DataType dtype) { return self->initAllgatherContext (comm, input, output, count, dtype); },
216+ [self](const void * input, void * output, size_t count, mscclpp::DataType dtype) {
217+ return self->generateAllgatherContextKey (input, output, count, dtype);
215218 });
216219 return allgatherAlgo;
217220}
0 commit comments