Skip to content

Commit b9f029b

Browse files
committed
create_tensor/free_tensor with tensor_id
1 parent feea94d commit b9f029b

File tree

7 files changed

+226
-234
lines changed

7 files changed

+226
-234
lines changed

p2p/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ help:
9595
@echo " install-deps - Install pybind11 dependency"
9696
@echo " help - Show this help message"
9797

98-
.PHONY: all clean test install-deps help install
98+
.PHONY: all clean test install-deps help install

p2p/MakefileHip

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,10 @@ PYTHON_LDFLAGS := $(shell $(PYTHON_CONFIG) --ldflags)
2222
PYTHON_SITE_PACKAGES := $(shell $(PYTHON) -c "import site; print(site.getsitepackages()[0])")
2323
INSTALL_DIR := $(PYTHON_SITE_PACKAGES)/uccl
2424

25-
# Pytorch
26-
TORCH_PATH := $(shell $(PYTHON) -c "import torch; print(torch.__path__[0])")
27-
TORCH_INC := -I$(TORCH_PATH)/include -I$(TORCH_PATH)/include/torch/csrc/api/include
28-
TORCH_LIB := -L$(TORCH_PATH)/lib -ltorch -ltorch_cpu -ltorch_python -ltorch_hip -lc10_hip
29-
30-
CXXFLAGS += -D__HIP_PLATFORM_AMD__ $(TORCH_INC)
25+
CXXFLAGS += -D__HIP_PLATFORM_AMD__
3126
LDFLAGS = -L$(HIP_LIB) -lamdhip64 \
3227
-Wl,-rpath,$(HIP_LIB) -I${CONDA_LIB_HOME}/../include -L${CONDA_LIB_HOME} -lglog -lgflags -lgtest \
33-
-lz -lelf -libverbs -lpthread $(TORCH_LIB) -Wl,-rpath,$(TORCH_PATH)/lib
28+
-lz -lelf -libverbs -lpthread
3429

3530
# Target and source files
3631
TARGET := p2p$(PYEXT)

p2p/pybind_engine.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,20 @@ PYBIND11_MODULE(p2p, m) {
1111

1212
m.def("get_oob_ip", &get_oob_ip, "Get the OOB IP address");
1313

14-
m.def("create_tensor", [](int gpu_index, size_t num_elems, size_t dtype_size, bool requires_grad = false) {
15-
uint64_t mr_id, ipc_id;
16-
auto tensor = create_tensor(gpu_index, num_elems, dtype_size, mr_id, ipc_id, requires_grad);
17-
return std::make_tuple(tensor, mr_id, ipc_id);
18-
}, "Create a tensor with RDMA capabilities",
19-
py::arg("gpu_index"), py::arg("num_elems"), py::arg("dtype_size"), py::arg("requires_grad") = false);
14+
m.def(
15+
"reg_mem",
16+
[](int gpu_id, uint64_t addr, size_t size) {
17+
uint64_t mem_id;
18+
reg_mem(gpu_id, reinterpret_cast<void*>(addr), size, mem_id);
19+
return mem_id;
20+
},
21+
"Reg the memory with RDMA capabilities", py::arg("gpu_id"),
22+
py::arg("addr"), py::arg("size"));
2023

21-
m.def("free_tensor", [](torch::Tensor& tensor, uint64_t mr_id, uint64_t ipc_id) {
22-
free_tensor(tensor, mr_id, ipc_id);
23-
}, "Free the tensor and associated RDMA resources",
24-
py::arg("tensor"), py::arg("mr_id"), py::arg("ipc_id"));
24+
m.def(
25+
"dereg_mem",
26+
[](uint64_t mem_id) { dereg_mem(mem_id); },
27+
"Dereg the memory associated RDMA resources", py::arg("mem_id"));
2528

2629
// Endpoint class binding
2730
py::class_<Endpoint>(m, "Endpoint")

p2p/tensor.cc

Lines changed: 123 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,205 +1,161 @@
11
#include "tensor.h"
22
#include <cerrno>
3-
#include <stdexcept>
43
#include <iostream>
4+
#include <stdexcept>
55

66
std::shared_mutex mr_mapping_mu_;
77
std::unordered_map<uint64_t, std::unique_ptr<MR>> mr_mapping_;
8-
std::atomic<uint64_t> next_mr_id_{0};
98

109
std::shared_mutex ipc_handle_mapping_mu_;
1110
std::unordered_map<uint64_t, std::unique_ptr<IPCMemHandle>> ipc_handle_mapping_;
12-
std::atomic<uint64_t> next_ipc_id_{0};
1311

14-
int reg_dma_mr(uccl::FactoryDevice* dev, void* addr, size_t len, int type, int offset,
12+
std::atomic<uint64_t> next_mem_id_{0};
13+
14+
int reg_dma_mr(uccl::FactoryDevice* dev, void* addr, size_t len, int offset,
1515
int fd, struct uccl::Mhandle** mhandle) {
16-
bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
16+
bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
1717

18-
unsigned int flags =
19-
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
20-
if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
18+
unsigned int flags =
19+
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
20+
if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
2121

22-
*mhandle = new uccl::Mhandle();
23-
(*mhandle)->mr = ibv_reg_dmabuf_mr(dev->pd, offset, len,
24-
(uint64_t)addr, fd, flags);
25-
return 0;
22+
*mhandle = new uccl::Mhandle();
23+
(*mhandle)->mr =
24+
ibv_reg_dmabuf_mr(dev->pd, offset, len, (uint64_t)addr, fd, flags);
25+
return 0;
2626
}
2727

28-
int reg_mr(uccl::FactoryDevice* dev, void* addr, size_t len, struct uccl::Mhandle** mhandle) {
29-
bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
30-
31-
unsigned int flags =
32-
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
33-
if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
34-
35-
*mhandle = new uccl::Mhandle();
36-
if (ib_relaxed_ordering_enabled_) {
37-
(*mhandle)->mr =
38-
ibv_reg_mr_iova2(dev->pd, addr, len, (uint64_t)addr, flags);
39-
} else {
40-
(*mhandle)->mr = ibv_reg_mr(dev->pd, addr, len, flags);
41-
}
42-
if (!(*mhandle)->mr) {
43-
std::cerr << "ibv_reg_mr failed (" << strerror(errno) << "), len=" << len
44-
<< " addr=" << addr << "\n";
45-
delete *mhandle;
46-
*mhandle = nullptr;
47-
return -1;
48-
}
49-
return 0;
28+
int reg_mr(uccl::FactoryDevice* dev, void* addr, size_t len,
29+
struct uccl::Mhandle** mhandle) {
30+
bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
31+
32+
unsigned int flags =
33+
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
34+
if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
35+
36+
*mhandle = new uccl::Mhandle();
37+
if (ib_relaxed_ordering_enabled_) {
38+
(*mhandle)->mr =
39+
ibv_reg_mr_iova2(dev->pd, addr, len, (uint64_t)addr, flags);
40+
} else {
41+
(*mhandle)->mr = ibv_reg_mr(dev->pd, addr, len, flags);
42+
}
43+
if (!(*mhandle)->mr) {
44+
std::cerr << "ibv_reg_mr failed (" << strerror(errno) << "), len=" << len
45+
<< " addr=" << addr << "\n";
46+
delete *mhandle;
47+
*mhandle = nullptr;
48+
return -1;
49+
}
50+
return 0;
5051
}
5152

5253
void dereg_mr(struct uccl::Mhandle* mhandle) {
53-
ibv_dereg_mr(mhandle->mr);
54-
delete mhandle;
54+
ibv_dereg_mr(mhandle->mr);
55+
delete mhandle;
5556
}
5657

5758
int get_ipc_handle(void* addr, struct IPCMemHandle* ipchandle) {
58-
GPU_RT_CHECK(gpuIpcGetMemHandle(&ipchandle->handle, reinterpret_cast<void*>(addr)));
59-
return 0;
59+
GPU_RT_CHECK(
60+
gpuIpcGetMemHandle(&ipchandle->handle, reinterpret_cast<void*>(addr)));
61+
return 0;
6062
}
6163

6264
int open_ipc_handle(void* addr, struct IPCMemHandle* ipchandle) {
63-
GPU_RT_CHECK(gpuIpcOpenMemHandle(&addr, ipchandle->handle,
65+
GPU_RT_CHECK(gpuIpcOpenMemHandle(&addr, ipchandle->handle,
6466
gpuIpcMemLazyEnablePeerAccess));
65-
return 0;
67+
return 0;
6668
}
6769

68-
torch::Dtype torch_dtype_from_size(size_t dtype_size) {
69-
switch (dtype_size) {
70-
case 1:
71-
return torch::kInt8;
72-
case 2:
73-
return torch::kInt16;
74-
case 4:
75-
return torch::kInt32;
76-
case 8:
77-
return torch::kInt64;
78-
default:
79-
throw std::runtime_error("Unsupported dtype size: " +
80-
std::to_string(dtype_size));
81-
}
82-
}
83-
84-
torch::Tensor create_tensor(int gpu_index, size_t num_elems, size_t dtype_size,
85-
uint64_t& mr_id, uint64_t& ipc_id, bool requires_grad) {
86-
std::cout << "[create_tensor] gpu_index=" << gpu_index
87-
<< " num_elems=" << num_elems
88-
<< " dtype_size=" << dtype_size
89-
<< " requires_grad=" << requires_grad << std::endl;
90-
91-
GPU_RT_CHECK(gpuSetDevice(gpu_index));
92-
uccl::FactoryDevice* factory_dev = uccl::RDMAFactory::get_factory_dev(gpu_to_dev[gpu_index]);
93-
std::cout << "[create_tensor] Got factory_dev for gpu_index " << gpu_index << std::endl;
94-
95-
size_t bytes = num_elems * dtype_size;
96-
size_t alignment = kIpcAlignment;
97-
std::cout << "[create_tensor] Allocating " << bytes << " bytes (aligned to " << alignment << ")" << std::endl;
98-
99-
void* raw_ptr;
100-
GPU_RT_CHECK(gpuMalloc(&raw_ptr, bytes + alignment));
101-
std::cout << "[create_tensor] gpuMalloc success, raw_ptr=" << raw_ptr << std::endl;
102-
103-
uintptr_t aligned_addr = (reinterpret_cast<uintptr_t>(raw_ptr) + alignment - 1) & ~(alignment - 1);
104-
void* aligned_ptr = reinterpret_cast<void*>(aligned_addr);
105-
std::cout << "[create_tensor] Aligned pointer=" << aligned_ptr << std::endl;
106-
107-
// Tensor
108-
auto dtype_ = torch_dtype_from_size(dtype_size);
109-
auto dev = torch::Device(torch_dev, gpu_index);
110-
auto options = torch::TensorOptions().dtype(dtype_).device(dev).requires_grad(requires_grad);
111-
auto deleter = [raw_ptr](void* ptr) {
112-
std::cout << "[create_tensor] Deleter freeing raw_ptr=" << raw_ptr << std::endl;
113-
GPU_RT_CHECK(gpuFree(raw_ptr));
114-
};
115-
torch::Tensor tensor = torch::from_blob(aligned_ptr, {static_cast<long>(num_elems)}, deleter, options);
116-
std::cout << "[create_tensor] Torch tensor created: sizes=" << tensor.sizes() << std::endl;
117-
118-
// MR
119-
std::unique_ptr<MR> mr = std::make_unique<MR>();
120-
int ret = reg_mr(factory_dev, aligned_ptr, bytes, &mr->mhandle_);
121-
if (ret != 0) {
122-
GPU_RT_CHECK(gpuFree(raw_ptr));
123-
throw std::runtime_error("MR registration failed");
124-
}
125-
mr->mr_id_ = next_mr_id_.fetch_add(1);
70+
void reg_mem(int gpu_id, void* addr, size_t size, uint64_t& mem_id) {
71+
if (gpu_id < 0 || gpu_id >= kMaxNumGPUs) {
72+
throw std::invalid_argument("[reg_mem] Invalid gpu_id: " +
73+
std::to_string(gpu_id));
74+
}
75+
76+
if (gpu_to_dev[gpu_id] == 0) {
77+
throw std::runtime_error(
78+
"You must initialize UCCL collective context or Endpoint first");
79+
}
80+
81+
GPU_RT_CHECK(gpuSetDevice(gpu_id));
82+
83+
uccl::FactoryDevice* factory_dev =
84+
uccl::RDMAFactory::get_factory_dev(gpu_to_dev[gpu_id]);
85+
86+
mem_id = next_mem_id_.fetch_add(1);
87+
// MR
88+
std::unique_ptr<MR> mr = std::make_unique<MR>();
89+
int ret = reg_mr(factory_dev, addr, size, &mr->mhandle_);
90+
if (ret != 0) {
91+
throw std::runtime_error("MR registration failed");
92+
}
93+
mr->mr_id_ = mem_id;
94+
{
95+
std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
96+
mr_mapping_[mr->mr_id_] = std::move(mr);
97+
}
98+
99+
// IPC
100+
auto addr_aligned = reinterpret_cast<uintptr_t>(addr) & ~(kIpcAlignment - 1);
101+
auto addr_offset = reinterpret_cast<uintptr_t>(addr) - addr_aligned;
102+
// std::cout << "[reg_mem] Aligned pointer=" << addr_aligned << std::endl;
103+
104+
std::unique_ptr<IPCMemHandle> ipc = std::make_unique<IPCMemHandle>();
105+
ret = get_ipc_handle(reinterpret_cast<void*>(addr_aligned), ipc.get());
106+
if (ret != 0) {
126107
{
127-
std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
128-
mr_mapping_[mr->mr_id_] = std::move(mr);
108+
std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
109+
mr_mapping_.erase(mem_id);
129110
}
130-
mr_id = mr->mr_id_;
131-
std::cout << "[create_tensor] MR registered, mr_id=" << mr_id << std::endl;
132-
133-
// IPC
134-
std::unique_ptr<IPCMemHandle> ipc = std::make_unique<IPCMemHandle>();
135-
ret = get_ipc_handle(aligned_ptr, ipc.get());
136-
if (ret != 0) {
137-
{
138-
std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
139-
mr_mapping_.erase(mr_id);
140-
}
141-
GPU_RT_CHECK(gpuFree(raw_ptr));
142-
throw std::runtime_error("IPC handle creation failed");
143-
}
144-
ipc->id = next_ipc_id_.fetch_add(1);
145-
ipc->size = bytes;
146-
{
147-
std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
148-
ipc_handle_mapping_[ipc->id] = std::move(ipc);
149-
}
150-
ipc_id = ipc->id;
151-
std::cout << "[create_tensor] IPC handle created, ipc_id=" << ipc_id
152-
<< " size=" << bytes << std::endl;
153-
154-
std::cout << "[create_tensor] SUCCESS: returning tensor with mr_id=" << mr_id
155-
<< " ipc_id=" << ipc_id << std::endl;
156-
157-
return tensor;
111+
throw std::runtime_error("[reg_mem] IPC handle creation failed");
112+
}
113+
ipc->size = size;
114+
ipc->offset = addr_offset;
115+
ipc->id = mem_id;
116+
{
117+
std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
118+
ipc_handle_mapping_[ipc->id] = std::move(ipc);
119+
}
158120
}
159121

160-
161-
void free_tensor(torch::Tensor& tensor, uint64_t mr_id, uint64_t ipc_id) {
162-
if (tensor.defined()) {
163-
tensor.reset();
164-
}
165-
166-
{
167-
std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
168-
auto it = mr_mapping_.find(mr_id);
169-
if (it != mr_mapping_.end()) {
170-
dereg_mr(it->second->mhandle_);
171-
mr_mapping_.erase(it);
172-
} else {
173-
std::cerr << "[free_tensor] MR id " << mr_id << " not found!\n";
174-
}
122+
void dereg_mem(uint64_t mem_id) {
123+
{
124+
std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
125+
auto it = mr_mapping_.find(mem_id);
126+
if (it != mr_mapping_.end()) {
127+
dereg_mr(it->second->mhandle_);
128+
mr_mapping_.erase(it);
129+
} else {
130+
std::cerr << "[free_tensor] MR id " << mem_id << " not found!\n";
175131
}
176-
177-
{
178-
std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
179-
auto it = ipc_handle_mapping_.find(ipc_id);
180-
if (it != ipc_handle_mapping_.end()) {
181-
ipc_handle_mapping_.erase(it);
182-
} else {
183-
std::cerr << "[free_tensor] IPC id " << ipc_id << " not found!\n";
184-
}
132+
}
133+
{
134+
std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
135+
auto it = ipc_handle_mapping_.find(mem_id);
136+
if (it != ipc_handle_mapping_.end()) {
137+
ipc_handle_mapping_.erase(it);
138+
} else {
139+
std::cerr << "[free_tensor] IPC id " << mem_id << " not found!\n";
185140
}
141+
}
186142
}
187143

188-
ibv_mr* get_mr_ibv_mr(uint64_t mr_id) {
189-
std::shared_lock<std::shared_mutex> lock(mr_mapping_mu_);
190-
auto it = mr_mapping_.find(mr_id);
191-
if (it != mr_mapping_.end()) {
192-
return it->second->mhandle_->mr;
193-
}
194-
return nullptr;
144+
ibv_mr* get_ibv_mr_by_mem_id(uint64_t mem_id) {
145+
std::shared_lock<std::shared_mutex> lock(mr_mapping_mu_);
146+
auto it = mr_mapping_.find(mem_id);
147+
if (it != mr_mapping_.end()) {
148+
return it->second->mhandle_->mr;
149+
}
150+
return nullptr;
195151
}
196152

197-
gpuIpcMemHandle_t get_ipc_mem_handle(uint64_t ipc_id) {
198-
std::shared_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
199-
auto it = ipc_handle_mapping_.find(ipc_id);
200-
if (it != ipc_handle_mapping_.end()) {
201-
return it->second->handle;
202-
}
203-
gpuIpcMemHandle_t handle = {};
204-
return handle;
153+
gpuIpcMemHandle_t get_ipc_mem_handle_by_mem_id(uint64_t mem_id) {
154+
std::shared_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
155+
auto it = ipc_handle_mapping_.find(mem_id);
156+
if (it != ipc_handle_mapping_.end()) {
157+
return it->second->handle;
158+
}
159+
gpuIpcMemHandle_t handle = {};
160+
return handle;
205161
}

0 commit comments

Comments
 (0)