Commit cef1f89

update
1 parent e2a5d69 commit cef1f89

File tree

5 files changed, +190 -4 lines changed
File renamed without changes.
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_default_algo.py

import os
import torch
import mscclpp.comm as mscclpp_comm
import mscclpp
import netifaces as ni
import ipaddress
import ctypes


def load_algorithms(scratch_buffer: torch.Tensor, rank: int) -> mscclpp.AlgorithmCollection:
    collection_builder = mscclpp.AlgorithmCollectionBuilder()
    return collection_builder.build_default_algorithms(
        scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank
    )


def interfaces_for_ip_netifaces(ip: str):
    # Find the network interface whose IPv4 address matches the given IP.
    target = ipaddress.ip_address(ip)
    for interface in ni.interfaces():
        addresses = ni.ifaddresses(interface)
        if ni.AF_INET in addresses:
            for link in addresses[ni.AF_INET]:
                if "addr" in link:
                    addr = ipaddress.ip_address(link["addr"])
                    if addr == target:
                        return interface
    return None


def dtype_to_mscclpp_dtype(dtype: torch.dtype) -> mscclpp.DataType:
    if dtype == torch.float16:
        return mscclpp.DataType.float16
    elif dtype == torch.float32:
        return mscclpp.DataType.float32
    elif dtype == torch.int32:
        return mscclpp.DataType.int32
    elif dtype == torch.bfloat16:
        return mscclpp.DataType.bfloat16
    else:
        raise ValueError(f"Unknown data type: {dtype}")


class CustomizedComm:
    def __init__(self, comm: mscclpp_comm.CommGroup):
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.local_rank = comm.my_rank % comm.nranks_per_node
        self.n_ranks_per_node = comm.nranks_per_node
        # 128 MiB GPU scratch buffer, exposed to torch via DLPack and handed to the default algorithms below.
        dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
        self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
        algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
        self._algorithm_nvls_packet = [
            algo
            for algo in algorithms
            if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet"
        ][0]
        self._algorithm_nvls_nonzero_copy = [
            algo
            for algo in algorithms
            if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_with_copy"
        ][0]

    def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
        assert op == torch.distributed.ReduceOp.SUM
        # Packet-based NVLS kernel for messages under 1 MiB, copy-based NVLS kernel otherwise.
        if tensor.nbytes < 1 << 20:
            algo = self._algorithm_nvls_packet
        else:
            algo = self._algorithm_nvls_nonzero_copy
        # The reduce op is passed by address through the extras dict.
        ctype_op = ctypes.c_int32(op.value)
        extras: dict[str, int] = {"op": ctypes.addressof(ctype_op)}
        algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=tensor.nbytes,
            output_size=tensor.nbytes,
            dtype=dtype_to_mscclpp_dtype(tensor.dtype),
            stream=stream.cuda_stream if stream is not None else 0,
            extras=extras,
        )

    def barrier(self):
        tensor = torch.empty(1, dtype=torch.float, device=torch.device("cuda"))
        self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())

    def destroy(self):
        self.executor = None
        self._algorithm_nvls_nonzero_copy = None
        self._algorithm_nvls_packet = None
        self.scratch_buffer = None
        self.comm = None


def init_dist() -> CustomizedComm:
    rank = int(os.environ["RANK"])
    world = int(os.environ["WORLD_SIZE"])
    master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
    master_port = os.environ["MSCCLPP_MASTER_PORT"]
    interface = interfaces_for_ip_netifaces(master_addr)
    if interface is None:
        raise ValueError(f"Cannot find network interface for IP address {master_addr}")
    # Determine ranks per node from the environment, falling back to the visible GPU count.
    nranks_per_node = os.environ.get("MSCCLPP_NRANKS_PER_NODE")
    if nranks_per_node is None:
        nranks_per_node = os.environ.get("LOCAL_WORLD_SIZE")
    if nranks_per_node is None:
        nnodes = int(os.environ.get("NNODES", "1"))
        if world % nnodes == 0:
            nranks_per_node = world // nnodes
    if nranks_per_node is None:
        nranks_per_node = torch.cuda.device_count()
    nranks_per_node = int(nranks_per_node)
    nranks_per_node = max(1, min(world, nranks_per_node))
    interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
    mscclpp_group = mscclpp_comm.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world)
    return CustomizedComm(mscclpp_group)


def main():
    local = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local)
    comm = init_dist()
    comm.barrier()
    input_data = torch.randn(1 << 22, dtype=torch.float16, device=torch.device("cuda"))
    comm.all_reduce(input_data, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
    comm.barrier()
    comm.destroy()


if __name__ == "__main__":
    main()

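The example above initializes the group, runs one all-reduce, and exits without checking the output. A minimal sanity check one could append to main() is sketched below; it is a hypothetical addition, not part of this commit, and relies only on names defined in the example: a SUM all-reduce of ones across world_size ranks should produce world_size in every element.

def check_allreduce(comm: CustomizedComm) -> None:
    # Hypothetical helper (not in this commit): verify the SUM all-reduce result.
    data = torch.ones(1 << 20, dtype=torch.float16, device=torch.device("cuda"))
    comm.all_reduce(data, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
    torch.cuda.synchronize()
    assert torch.allclose(data, torch.full_like(data, comm.world_size))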
python/csrc/algorithm.cpp

Lines changed: 11 additions & 0 deletions
@@ -22,6 +22,17 @@ void register_algorithm(nb::module_& m) {
 
   nb::enum_<AlgorithmType>(m, "AlgorithmType").value("NATIVE", AlgorithmType::NATIVE).value("DSL", AlgorithmType::DSL);
 
+  nb::enum_<CommResult>(m, "CommResult")
+      .value("COMM_SUCCESS", CommResult::commSuccess)
+      .value("COMM_UNHANDLED_CUDA_ERROR", CommResult::commUnhandledCudaError)
+      .value("COMM_SYSTEM_ERROR", CommResult::commSystemError)
+      .value("COMM_INTERNAL_ERROR", CommResult::commInternalError)
+      .value("COMM_INVALID_ARGUMENT", CommResult::commInvalidArgument)
+      .value("COMM_INVALID_USAGE", CommResult::commInvalidUsage)
+      .value("COMM_REMOTE_ERROR", CommResult::commRemoteError)
+      .value("COMM_IN_PROGRESS", CommResult::commInProgress)
+      .value("COMM_NUM_RESULTS", CommResult::commNumResults);
+
   auto algorithmClass =
       nb::class_<Algorithm>(m, "Algorithm")
           .def_static(

python/mscclpp/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 from functools import wraps
 from mscclpp._version import __version__, __commit_id__
-from mscclpp._algorithm import Algorithm, AlgorithmCollectionBuilder
+from mscclpp._algorithm import Algorithm, AlgorithmCollectionBuilder, AlgorithmCollection
 from mscclpp.language.utils import AlgoSpec
 from mscclpp._compiler import DslCompiler, NativeCodeCompiler
 
@@ -91,6 +91,7 @@
     # Python API
     "Algorithm",
     "AlgorithmCollectionBuilder",
+    "AlgorithmCollection",
     "AlgoSpec",
 ]

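Re-exporting AlgorithmCollection lets downstream code type-annotate and filter collections through the package namespace. A small hedged sketch follows; the helper is hypothetical and relies only on attributes used elsewhere in this commit (get_by_collective, algo.name).

import mscclpp

def find_allreduce(algos: mscclpp.AlgorithmCollection, name: str) -> mscclpp.Algorithm:
    # Hypothetical helper: pick one of the registered allreduce algorithms by name.
    matches = [algo for algo in algos.get_by_collective("allreduce") if algo.name == name]
    if not matches:
        raise KeyError(f"no allreduce algorithm named {name!r}")
    return matches[0]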
python/mscclpp/_algorithm.py

Lines changed: 40 additions & 3 deletions
@@ -11,6 +11,7 @@
     DslAlgorithm as _DslAlgorithm,
     AlgorithmType as _AlgorithmType,
     AlgorithmBuilder as _AlgorithmBuilder,
+    AlgorithmCollection as _AlgorithmCollection,
     AlgorithmCollectionBuilder as _AlgorithmCollectionBuilder,
     Communicator,
     CollectiveBufferMode,
@@ -126,6 +127,36 @@ def build(self) -> Algorithm:
         return Algorithm.create_from_native_handle(self._algorithm_builder.build())
 
 
+class AlgorithmCollection:
+    def __init__(self, native_collection: _AlgorithmCollection):
+        self._native_collection = native_collection
+        self._algorithms = [
+            Algorithm.create_from_native_handle(algo)
+            for algo in self._native_collection.to_list()
+        ]
+
+    def __iter__(self):
+        """Iterate over all algorithms in the collection."""
+        return iter(self._algorithms)
+
+    def __len__(self):
+        """Return the number of algorithms in the collection."""
+        return len(self._algorithms)
+
+    def __getitem__(self, index: int) -> Algorithm:
+        """Get an algorithm by index."""
+        return self._algorithms[index]
+
+    def get_by_collective(self, collective: str):
+        """Get all algorithms for a specific collective operation."""
+        return [algo for algo in self._algorithms if algo.collective == collective]
+
+    def register_algorithm(self, collective: str, algo_name: str, algorithm: Algorithm):
+        """Register an algorithm for a collective operation."""
+        self._native_collection.register_algorithm(collective, algo_name, algorithm._algorithm)
+        self._algorithms.append(algorithm)
+
+
 class AlgorithmCollectionBuilder:
     _instance = None
 
@@ -161,8 +192,14 @@ def set_algorithm_selector(self, selector):
     def set_fallback_algorithm_selector(self, selector):
         self._builder.set_fallback_algorithm_selector(selector)
 
-    def build(self):
-        return self._builder.build()
-
+    def build(self) -> AlgorithmCollection:
+        collection = self._builder.build()
+        return AlgorithmCollection(collection)
+
+    def build_default_algorithms(self, scratch_buffer: int, scratch_buffer_size: int, rank: int) -> AlgorithmCollection:
+        native_collection = self._builder.build_default_algorithms(
+            int(scratch_buffer), scratch_buffer_size, rank
+        )
+        return AlgorithmCollection(native_collection)
 
 atexit.register(AlgorithmCollectionBuilder.reset)

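Taken together, these changes give a Python-level path from the builder to a filterable collection. Below is a hedged end-to-end sketch: it assumes a torchrun-style launch (RANK/LOCAL_RANK set), that a plain GPU tensor is acceptable as the scratch buffer (the example file uses mscclpp.RawGpuBuffer instead), and that any communicator setup the defaults require has already happened, as it has in the example where the CommGroup is created first.

import os
import torch
import mscclpp

# Hedged sketch of the builder entry point added in this commit.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
scratch = torch.empty(1 << 27, dtype=torch.uint8, device="cuda")  # assumption: a plain GPU tensor as scratch
algos = mscclpp.AlgorithmCollectionBuilder().build_default_algorithms(
    scratch_buffer=scratch.data_ptr(), scratch_buffer_size=scratch.nbytes, rank=int(os.environ["RANK"])
)
print(f"{len(algos)} algorithms")                   # AlgorithmCollection supports len() ...
for algo in algos.get_by_collective("allreduce"):   # ... iteration, and per-collective filtering
    print(algo.name)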