Commit 8ccaf09

WIP

1 parent 9e97afb

6 files changed: 34 additions & 36 deletions

examples/customized-collective-algorithm/customized_allgather.cu

Lines changed: 8 additions & 9 deletions
@@ -90,15 +90,14 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
     std::shared_ptr<mscclpp::Algorithm> allgatherAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
         "allgather", "allgather", [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
         [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
-               size_t outputSize, int dtype, cudaStream_t stream, std::unordered_map<std::string, uintptr_t>& extras) {
-          return self->allgatherKernelFunc(ctx, input, output, inputSize, static_cast<ncclDataType_t>(dtype), stream,
-                                           extras);
+               size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
+               std::unordered_map<std::string, uintptr_t>& extras) {
+          return self->allgatherKernelFunc(ctx, input, output, inputSize, dtype, stream, extras);
         },
         [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
-               size_t outputSize, int dtype) {
-          return self->initAllgatherContext(comm, input, output, inputSize, static_cast<ncclDataType_t>(dtype));
-        },
-        [self](const void* input, void* output, size_t inputSize, size_t outputSize, int dtype) {
+               size_t outputSize,
+               mscclpp::DataType dtype) { return self->initAllgatherContext(comm, input, output, inputSize, dtype); },
+        [self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
           return self->generateAllgatherContextKey(input, output, inputSize, outputSize,
                                                    static_cast<ncclDataType_t>(dtype));
         });
@@ -126,7 +125,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
   }

   ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
-                                   size_t inputSize, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
+                                   size_t inputSize, [[maybe_unused]] mscclpp::DataType dtype, cudaStream_t stream,
                                    std::unordered_map<std::string, uintptr_t>& extras) {
     int rank = ctx->rank;
     int worldSize = ctx->workSize;
@@ -141,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {

   std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
                                                               const void* input, void* output, size_t inputSize,
-                                                              ncclDataType_t dtype) {
+                                                              mscclpp::DataType dtype) {
     auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
     ctx->rank = comm->bootstrap()->getRank();
     ctx->workSize = comm->bootstrap()->getNranks();

examples/torch-integration/customized_allgather.cu

Lines changed: 14 additions & 16 deletions
@@ -56,43 +56,41 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
     std::shared_ptr<mscclpp::Algorithm> allgatherAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
         "allgather", "allgather", [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
         [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
-               size_t outputSize, int dtype, cudaStream_t stream, std::unordered_map<std::string, uintptr_t>& extras) {
-          return self->allgatherKernelFunc(ctx, input, output, inputSize, static_cast<ncclDataType_t>(dtype), stream,
-                                           extras);
+               size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
+               std::unordered_map<std::string, uintptr_t>& extras) {
+          return self->allgatherKernelFunc(ctx, input, output, inputSize, dtype, stream, extras);
         },
         [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
-               size_t outputSize, int dtype) {
-          return self->initAllgatherContext(comm, input, output, inputSize, static_cast<ncclDataType_t>(dtype));
-        },
-        [self](const void* input, void* output, size_t inputSize, size_t outputSize, int dtype) {
-          return self->generateAllgatherContextKey(input, output, inputSize, outputSize,
-                                                   static_cast<ncclDataType_t>(dtype));
+               size_t outputSize,
+               mscclpp::DataType dtype) { return self->initAllgatherContext(comm, input, output, inputSize, dtype); },
+        [self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
+          return self->generateAllgatherContextKey(input, output, inputSize, outputSize, dtype);
         });
     return allgatherAlgo;
   }

  private:
-  std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
+  std::vector<mscclpp::Connection> conns_;
   std::shared_ptr<mscclpp::ProxyService> proxyService_;
   int worldSize_;

   void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
-    std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
+    std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
     worldSize_ = comm->bootstrap()->getNranks();
     for (int i = 0; i < worldSize_; i++) {
       if (i == comm->bootstrap()->getRank()) continue;
       connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
     }
-    std::vector<std::shared_ptr<mscclpp::Connection>> connections;
+    std::vector<mscclpp::Connection> connections;
     std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
                    [](const auto& future) { return future.get(); });
     this->conns_ = std::move(connections);
     proxyService_ = std::make_shared<mscclpp::ProxyService>();
-    proxyService_->startProxy();
+    proxyService_->startProxy(true);
   }

   ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
-                                   size_t inputBytes, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
+                                   size_t inputBytes, [[maybe_unused]] mscclpp::DataType dtype, cudaStream_t stream,
                                    std::unordered_map<std::string, uintptr_t>& extras) {
     int rank = ctx->rank;
     int worldSize = ctx->workSize;
@@ -107,7 +105,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {

   std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
                                                               const void* input, void* output, size_t inputBytes,
-                                                              ncclDataType_t dtype) {
+                                                              mscclpp::DataType dtype) {
     auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
     ctx->rank = comm->bootstrap()->getRank();
     ctx->workSize = comm->bootstrap()->getNranks();
@@ -149,7 +147,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
   }

   mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void* input, void* output, size_t inputSize,
-                                                       size_t outputSize, ncclDataType_t dtype) {
+                                                       size_t outputSize, mscclpp::DataType dtype) {
     return {(void*)input, output, inputSize, outputSize, 0};
   }
 };

examples/torch-integration/dsl_with_nccl_api.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-# LD_PRELOAD=<MSCCLPP_REPO>/build/apps/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 torch-integration/dsl_with_nccl_api.py
+# LD_PRELOAD=<MSCCLPP_REPO>/build/apps/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py

 import os
 from typing import Any, Dict
@@ -111,6 +111,7 @@ def main():
     dist.all_reduce(x, op=dist.ReduceOp.SUM)
     dist.barrier()
     dist.destroy_process_group()
+    print(f"Rank {local} allreduce completed successfully.")


 if __name__ == "__main__":

python/csrc/core_py.cpp

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,13 @@ void def_shared_future(nb::handle& m, const std::string& typestr) {
 void register_core(nb::module_& m) {
   m.def("version", &version);

+  nb::enum_<DataType>(m, "DataType")
+      .value("int32", DataType::INT32)
+      .value("uint32", DataType::UINT32)
+      .value("float16", DataType::FLOAT16)
+      .value("float32", DataType::FLOAT32)
+      .value("bfloat16", DataType::BFLOAT16);
+
   nb::class_<Bootstrap>(m, "Bootstrap")
       .def("get_rank", &Bootstrap::getRank)
       .def("get_n_ranks", &Bootstrap::getNranks)

python/csrc/executor_py.cpp

Lines changed: 0 additions & 7 deletions
@@ -15,13 +15,6 @@ namespace nb = nanobind;
 using namespace mscclpp;

 void register_executor(nb::module_& m) {
-  nb::enum_<DataType>(m, "DataType")
-      .value("int32", DataType::INT32)
-      .value("uint32", DataType::UINT32)
-      .value("float16", DataType::FLOAT16)
-      .value("float32", DataType::FLOAT32)
-      .value("bfloat16", DataType::BFLOAT16);
-
   nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

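Taken together with the previous file, this moves the `DataType` binding from `register_executor` into `register_core`, so the enum is defined once, alongside the core types, rather than with the executor bindings. A minimal sketch of the Python-side effect; the top-level import path is an assumption based on how `python/mscclpp/_algorithm.py` imports `DataType` below, not something this commit shows directly:

    # Sketch only: DataType now comes from the core bindings.
    # The top-level re-export is assumed from the package's usual layout.
    from mscclpp import DataType

    dtype = DataType.float16  # members bound above: int32, uint32, float16, float32, bfloat16
    print(dtype)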

python/mscclpp/_algorithm.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
     AlgorithmCollectionBuilder as _AlgorithmCollectionBuilder,
     Communicator,
     CollectiveBufferMode,
-    DeviceType,
+    DataType,
     Executor,
     ExecutionPlan,
 )
@@ -100,7 +100,7 @@ def execute(
         output_buffer: int,
         input_size: int,
         output_size: int,
-        dtype: DeviceType,
+        dtype: DataType,
         stream: int,
         executor: Optional[Executor] = None,
         extras: Optional[Dict[str, int]] = None,
@@ -111,7 +111,7 @@ def execute(
             int(output_buffer),
             input_size,
             output_size,
-            int(dtype),
+            dtype,
             int(stream),
             executor,
             extras if extras is not None else {}
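
With the annotation corrected from `DeviceType` to `DataType`, `execute` now forwards the enum itself instead of coercing it with `int(...)`. An illustrative call under assumed placeholder names; `algo`, `in_ptr`, `out_ptr`, `nbytes`, and `stream_ptr` are not from this commit, and the leading `input_buffer` parameter name is inferred, since the diff shows the signature only from `output_buffer` on:

    # Illustrative only: pass the DataType member directly; no int() coercion.
    algo.execute(
        input_buffer=in_ptr,
        output_buffer=out_ptr,
        input_size=nbytes,
        output_size=nbytes,
        dtype=DataType.float32,
        stream=stream_ptr,
    )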
