@@ -56,43 +56,41 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
5656 std::shared_ptr<mscclpp::Algorithm> allgatherAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
5757 " allgather" , " allgather" , [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize (comm); },
5858 [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t inputSize,
59- size_t outputSize, int dtype, cudaStream_t stream, std::unordered_map<std::string, uintptr_t >& extras) {
60- return self-> allgatherKernelFunc (ctx, input, output, inputSize, static_cast <ncclDataType_t>(dtype), stream,
61- extras);
59+ size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
60+ std::unordered_map<std::string, uintptr_t >& extras) {
61+ return self-> allgatherKernelFunc (ctx, input, output, inputSize, dtype, stream, extras);
6262 },
6363 [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t inputSize,
64- size_t outputSize, int dtype) {
65- return self->initAllgatherContext (comm, input, output, inputSize, static_cast <ncclDataType_t>(dtype));
66- },
67- [self](const void * input, void * output, size_t inputSize, size_t outputSize, int dtype) {
68- return self->generateAllgatherContextKey (input, output, inputSize, outputSize,
69- static_cast <ncclDataType_t>(dtype));
64+ size_t outputSize,
65+ mscclpp::DataType dtype) { return self->initAllgatherContext (comm, input, output, inputSize, dtype); },
66+ [self](const void * input, void * output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
67+ return self->generateAllgatherContextKey (input, output, inputSize, outputSize, dtype);
7068 });
7169 return allgatherAlgo;
7270 }
7371
7472 private:
75- std::vector<std::shared_ptr< mscclpp::Connection> > conns_;
73+ std::vector<mscclpp::Connection> conns_;
7674 std::shared_ptr<mscclpp::ProxyService> proxyService_;
7775 int worldSize_;
7876
7977 void initialize (std::shared_ptr<mscclpp::Communicator> comm) {
80- std::vector<std::shared_future<std::shared_ptr< mscclpp::Connection> >> connectionFutures;
78+ std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
8179 worldSize_ = comm->bootstrap ()->getNranks ();
8280 for (int i = 0 ; i < worldSize_; i++) {
8381 if (i == comm->bootstrap ()->getRank ()) continue ;
8482 connectionFutures.push_back (comm->connect (mscclpp::Transport::CudaIpc, i));
8583 }
86- std::vector<std::shared_ptr< mscclpp::Connection> > connections;
84+ std::vector<mscclpp::Connection> connections;
8785 std::transform (connectionFutures.begin (), connectionFutures.end (), std::back_inserter (connections),
8886 [](const auto & future) { return future.get (); });
8987 this ->conns_ = std::move (connections);
9088 proxyService_ = std::make_shared<mscclpp::ProxyService>();
91- proxyService_->startProxy ();
89+ proxyService_->startProxy (true );
9290 }
9391
9492 ncclResult_t allgatherKernelFunc (const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output,
95- size_t inputBytes, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
93+ size_t inputBytes, [[maybe_unused]] mscclpp::DataType dtype, cudaStream_t stream,
9694 std::unordered_map<std::string, uintptr_t >& extras) {
9795 int rank = ctx->rank ;
9896 int worldSize = ctx->workSize ;
@@ -107,7 +105,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
107105
108106 std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext (std::shared_ptr<mscclpp::Communicator> comm,
109107 const void * input, void * output, size_t inputBytes,
110- ncclDataType_t dtype) {
108+ mscclpp::DataType dtype) {
111109 auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
112110 ctx->rank = comm->bootstrap ()->getRank ();
113111 ctx->workSize = comm->bootstrap ()->getNranks ();
@@ -149,7 +147,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
149147 }
150148
151149 mscclpp::AlgorithmCtxKey generateAllgatherContextKey (const void * input, void * output, size_t inputSize,
152- size_t outputSize, ncclDataType_t dtype) {
150+ size_t outputSize, mscclpp::DataType dtype) {
153151 return {(void *)input, output, inputSize, outputSize, 0 };
154152 }
155153};
0 commit comments