|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_default_algo.py |
| 5 | + |
| 6 | +import os |
| 7 | +import torch |
| 8 | +import mscclpp.comm as mscclpp_comm |
| 9 | +import mscclpp |
| 10 | +import netifaces as ni |
| 11 | +import ipaddress |
| 12 | +import ctypes |
| 13 | + |
| 14 | + |
| 15 | +def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection: |
| 16 | + collection_builder = mscclpp.AlgorithmCollectionBuilder() |
| 17 | + return collection_builder.build_default_algorithms( |
| 18 | + scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank |
| 19 | + ) |
| 20 | + |
| 21 | + |
| 22 | +def interfaces_for_ip_netifaces(ip: str): |
| 23 | + target = ipaddress.ip_address(ip) |
| 24 | + for interface in ni.interfaces(): |
| 25 | + addresses = ni.ifaddresses(interface) |
| 26 | + if ni.AF_INET in addresses: |
| 27 | + for link in addresses[ni.AF_INET]: |
| 28 | + if "addr" in link: |
| 29 | + addr = ipaddress.ip_address(link["addr"]) |
| 30 | + if addr == target: |
| 31 | + return interface |
| 32 | + return None |
| 33 | + |
| 34 | + |
| 35 | +def dtype_to_mscclpp_dtype(dtype: torch.dtype) -> mscclpp.DataType: |
| 36 | + if dtype == torch.float16: |
| 37 | + return mscclpp.DataType.float16 |
| 38 | + elif dtype == torch.float32: |
| 39 | + return mscclpp.DataType.float32 |
| 40 | + elif dtype == torch.int32: |
| 41 | + return mscclpp.DataType.int32 |
| 42 | + elif dtype == torch.bfloat16: |
| 43 | + return mscclpp.DataType.bfloat16 |
| 44 | + else: |
| 45 | + raise ValueError(f"Unknown data type: {dtype}") |
| 46 | + |
| 47 | + |
| 48 | +class CustomizedComm: |
| 49 | + def __init__(self, comm: mscclpp_comm.CommGroup): |
| 50 | + self.comm = comm |
| 51 | + self.rank = comm.my_rank |
| 52 | + self.world_size = comm.nranks |
| 53 | + self.local_rank = comm.my_rank % comm.nranks_per_node |
| 54 | + self.n_ranks_per_node = comm.nranks_per_node |
| 55 | + dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16)) |
| 56 | + self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack) |
| 57 | + algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank) |
| 58 | + self._algorithm_nvls_packet = [ |
| 59 | + algo |
| 60 | + for algo in algorithms |
| 61 | + if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet" |
| 62 | + ][0] |
| 63 | + self._algorithm_nvls_nonzero_copy = [ |
| 64 | + algo |
| 65 | + for algo in algorithms |
| 66 | + if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_with_copy" |
| 67 | + ][0] |
| 68 | + |
| 69 | + def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None): |
| 70 | + assert op == torch.distributed.ReduceOp.SUM |
| 71 | + algo = None |
| 72 | + if tensor.nbytes < 1 << 20: |
| 73 | + algo = self._algorithm_nvls_packet |
| 74 | + else: |
| 75 | + algo = self._algorithm_nvls_nonzero_copy |
| 76 | + ctype_op = ctypes.c_int32(op.value) |
| 77 | + extras: dict[str, int] = {"op": ctypes.addressof(ctype_op)} |
| 78 | + algo.execute( |
| 79 | + comm=self.comm.communicator, |
| 80 | + input_buffer=tensor.data_ptr(), |
| 81 | + output_buffer=tensor.data_ptr(), |
| 82 | + input_size=tensor.nbytes, |
| 83 | + output_size=tensor.nbytes, |
| 84 | + dtype=dtype_to_mscclpp_dtype(tensor.dtype), |
| 85 | + stream=stream.cuda_stream if stream is not None else 0, |
| 86 | + extras=extras, |
| 87 | + ) |
| 88 | + |
| 89 | + def barrier(self): |
| 90 | + tensor = torch.empty(1, dtype=torch.float, device=torch.device("cuda")) |
| 91 | + self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream()) |
| 92 | + |
| 93 | + def destroy(self): |
| 94 | + self.executor = None |
| 95 | + self._algorithm_nvls_nonzero_copy = None |
| 96 | + self._algorithm_nvls_packet = None |
| 97 | + self.scratch_buffer = None |
| 98 | + self.comm = None |
| 99 | + |
| 100 | + |
| 101 | +def init_dist() -> CustomizedComm: |
| 102 | + rank = int(os.environ["RANK"]) |
| 103 | + world = int(os.environ["WORLD_SIZE"]) |
| 104 | + master_addr = os.environ["MSCCLPP_MASTER_ADDR"] |
| 105 | + master_port = os.environ["MSCCLPP_MASTER_PORT"] |
| 106 | + interface = interfaces_for_ip_netifaces(master_addr) |
| 107 | + if interface is None: |
| 108 | + raise ValueError(f"Cannot find network interface for IP address {master_addr}") |
| 109 | + nranks_per_node = os.environ.get("MSCCLPP_NRANKS_PER_NODE") |
| 110 | + if nranks_per_node is None: |
| 111 | + nranks_per_node = os.environ.get("LOCAL_WORLD_SIZE") |
| 112 | + if nranks_per_node is None: |
| 113 | + nnodes = int(os.environ.get("NNODES", "1")) |
| 114 | + if world % nnodes == 0: |
| 115 | + nranks_per_node = world // nnodes |
| 116 | + if nranks_per_node is None: |
| 117 | + nranks_per_node = torch.cuda.device_count() |
| 118 | + nranks_per_node = int(nranks_per_node) |
| 119 | + nranks_per_node = max(1, min(world, nranks_per_node)) |
| 120 | + interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}" |
| 121 | + mscclpp_group = mscclpp_comm.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world) |
| 122 | + return CustomizedComm(mscclpp_group) |
| 123 | + |
| 124 | + |
| 125 | +def main(): |
| 126 | + local = int(os.environ["LOCAL_RANK"]) |
| 127 | + torch.cuda.set_device(local) |
| 128 | + comm = init_dist() |
| 129 | + comm.barrier() |
| 130 | + input_data = torch.randn(1 << 22, dtype=torch.float16, device=torch.device("cuda")) |
| 131 | + comm.all_reduce(input_data, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream()) |
| 132 | + comm.barrier() |
| 133 | + comm.destroy() |
| 134 | + |
| 135 | + |
| 136 | +if __name__ == "__main__": |
| 137 | + main() |
0 commit comments