#include "tensor.h"
#include <cerrno>
-#include <stdexcept>
#include <iostream>
+#include <stdexcept>

std::shared_mutex mr_mapping_mu_;
std::unordered_map<uint64_t, std::unique_ptr<MR>> mr_mapping_;
-std::atomic<uint64_t> next_mr_id_{0};

std::shared_mutex ipc_handle_mapping_mu_;
std::unordered_map<uint64_t, std::unique_ptr<IPCMemHandle>> ipc_handle_mapping_;
-std::atomic<uint64_t> next_ipc_id_{0};

-int reg_dma_mr(uccl::FactoryDevice* dev, void* addr, size_t len, int type, int offset,
+std::atomic<uint64_t> next_mem_id_{0};
+
+int reg_dma_mr(uccl::FactoryDevice* dev, void* addr, size_t len, int offset,
               int fd, struct uccl::Mhandle** mhandle) {
-    bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
+  bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();

-    unsigned int flags =
-        IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
-    if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
+  unsigned int flags =
+      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
+  if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;

-    *mhandle = new uccl::Mhandle();
-    (*mhandle)->mr = ibv_reg_dmabuf_mr(dev->pd, offset, len,
-                                       (uint64_t)addr, fd, flags);
-    return 0;
+  *mhandle = new uccl::Mhandle();
+  (*mhandle)->mr =
+      ibv_reg_dmabuf_mr(dev->pd, offset, len, (uint64_t)addr, fd, flags);
+  return 0;
}

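For context, a minimal usage sketch of the dma-buf path above, assuming the caller has already exported a dma-buf file descriptor for the GPU allocation (the export step is platform specific and not shown); `example_dmabuf_registration` and its arguments are hypothetical names:

// Illustrative sketch only; not part of the change above.
static void example_dmabuf_registration(uccl::FactoryDevice* dev, void* gpu_ptr,
                                        size_t len, int dmabuf_fd) {
  uccl::Mhandle* mh = nullptr;
  // offset 0: register starting at the buffer's base address.
  reg_dma_mr(dev, gpu_ptr, len, /*offset=*/0, dmabuf_fd, &mh);
  if (mh != nullptr && mh->mr != nullptr) {
    // mh->mr->lkey / mh->mr->rkey can now be used in RDMA work requests.
    dereg_mr(mh);
  } else {
    delete mh;  // ibv_reg_dmabuf_mr failed; only the wrapper was allocated.
  }
}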
-int reg_mr(uccl::FactoryDevice* dev, void* addr, size_t len, struct uccl::Mhandle** mhandle) {
-    bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
-
-    unsigned int flags =
-        IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
-    if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
-
-    *mhandle = new uccl::Mhandle();
-    if (ib_relaxed_ordering_enabled_) {
-        (*mhandle)->mr =
-            ibv_reg_mr_iova2(dev->pd, addr, len, (uint64_t)addr, flags);
-    } else {
-        (*mhandle)->mr = ibv_reg_mr(dev->pd, addr, len, flags);
-    }
-    if (!(*mhandle)->mr) {
-        std::cerr << "ibv_reg_mr failed (" << strerror(errno) << "), len=" << len
-                  << " addr=" << addr << "\n";
-        delete *mhandle;
-        *mhandle = nullptr;
-        return -1;
-    }
-    return 0;
+int reg_mr(uccl::FactoryDevice* dev, void* addr, size_t len,
+           struct uccl::Mhandle** mhandle) {
+  bool ib_relaxed_ordering_enabled_ = uccl::ncclIbRelaxedOrderingCapable();
+
+  unsigned int flags =
+      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
+  if (ib_relaxed_ordering_enabled_) flags |= IBV_ACCESS_RELAXED_ORDERING;
+
+  *mhandle = new uccl::Mhandle();
+  if (ib_relaxed_ordering_enabled_) {
+    (*mhandle)->mr =
+        ibv_reg_mr_iova2(dev->pd, addr, len, (uint64_t)addr, flags);
+  } else {
+    (*mhandle)->mr = ibv_reg_mr(dev->pd, addr, len, flags);
+  }
+  if (!(*mhandle)->mr) {
+    std::cerr << "ibv_reg_mr failed (" << strerror(errno) << "), len=" << len
+              << " addr=" << addr << "\n";
+    delete *mhandle;
+    *mhandle = nullptr;
+    return -1;
+  }
+  return 0;
}

void dereg_mr(struct uccl::Mhandle* mhandle) {
-    ibv_dereg_mr(mhandle->mr);
-    delete mhandle;
+  ibv_dereg_mr(mhandle->mr);
+  delete mhandle;
}

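A sketch of the ordinary register/deregister round trip with a device allocation; `dev` is assumed to come from `uccl::RDMAFactory::get_factory_dev`, and the buffer size is an arbitrary example value:

// Illustrative sketch only; not part of the change above.
static void example_host_registration(uccl::FactoryDevice* dev) {
  constexpr size_t kBytes = 1 << 20;  // 1 MiB example buffer
  void* buf = nullptr;
  GPU_RT_CHECK(gpuMalloc(&buf, kBytes));
  uccl::Mhandle* mh = nullptr;
  if (reg_mr(dev, buf, kBytes, &mh) == 0) {
    // ... post sends/reads/writes using mh->mr->lkey / mh->mr->rkey ...
    dereg_mr(mh);
  }  // on failure reg_mr has already logged errno and reset *mh to nullptr
  GPU_RT_CHECK(gpuFree(buf));
}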
int get_ipc_handle(void* addr, struct IPCMemHandle* ipchandle) {
-    GPU_RT_CHECK(gpuIpcGetMemHandle(&ipchandle->handle, reinterpret_cast<void*>(addr)));
-    return 0;
+  GPU_RT_CHECK(
+      gpuIpcGetMemHandle(&ipchandle->handle, reinterpret_cast<void*>(addr)));
+  return 0;
}

int open_ipc_handle(void* addr, struct IPCMemHandle* ipchandle) {
-    GPU_RT_CHECK(gpuIpcOpenMemHandle(&addr, ipchandle->handle,
+  GPU_RT_CHECK(gpuIpcOpenMemHandle(&addr, ipchandle->handle,
                                   gpuIpcMemLazyEnablePeerAccess));
-    return 0;
+  return 0;
}

-torch::Dtype torch_dtype_from_size(size_t dtype_size) {
-    switch (dtype_size) {
-        case 1:
-            return torch::kInt8;
-        case 2:
-            return torch::kInt16;
-        case 4:
-            return torch::kInt32;
-        case 8:
-            return torch::kInt64;
-        default:
-            throw std::runtime_error("Unsupported dtype size: " +
-                                     std::to_string(dtype_size));
-    }
-}
-
-torch::Tensor create_tensor(int gpu_index, size_t num_elems, size_t dtype_size,
-                            uint64_t& mr_id, uint64_t& ipc_id, bool requires_grad) {
-    std::cout << "[create_tensor] gpu_index=" << gpu_index
-              << " num_elems=" << num_elems
-              << " dtype_size=" << dtype_size
-              << " requires_grad=" << requires_grad << std::endl;
-
-    GPU_RT_CHECK(gpuSetDevice(gpu_index));
-    uccl::FactoryDevice* factory_dev = uccl::RDMAFactory::get_factory_dev(gpu_to_dev[gpu_index]);
-    std::cout << "[create_tensor] Got factory_dev for gpu_index " << gpu_index << std::endl;
-
-    size_t bytes = num_elems * dtype_size;
-    size_t alignment = kIpcAlignment;
-    std::cout << "[create_tensor] Allocating " << bytes << " bytes (aligned to " << alignment << ")" << std::endl;
-
-    void* raw_ptr;
-    GPU_RT_CHECK(gpuMalloc(&raw_ptr, bytes + alignment));
-    std::cout << "[create_tensor] gpuMalloc success, raw_ptr=" << raw_ptr << std::endl;
-
-    uintptr_t aligned_addr = (reinterpret_cast<uintptr_t>(raw_ptr) + alignment - 1) & ~(alignment - 1);
-    void* aligned_ptr = reinterpret_cast<void*>(aligned_addr);
-    std::cout << "[create_tensor] Aligned pointer=" << aligned_ptr << std::endl;
-
-    // Tensor
-    auto dtype_ = torch_dtype_from_size(dtype_size);
-    auto dev = torch::Device(torch_dev, gpu_index);
-    auto options = torch::TensorOptions().dtype(dtype_).device(dev).requires_grad(requires_grad);
-    auto deleter = [raw_ptr](void* ptr) {
-        std::cout << "[create_tensor] Deleter freeing raw_ptr=" << raw_ptr << std::endl;
-        GPU_RT_CHECK(gpuFree(raw_ptr));
-    };
-    torch::Tensor tensor = torch::from_blob(aligned_ptr, {static_cast<long>(num_elems)}, deleter, options);
-    std::cout << "[create_tensor] Torch tensor created: sizes=" << tensor.sizes() << std::endl;
-
-    // MR
-    std::unique_ptr<MR> mr = std::make_unique<MR>();
-    int ret = reg_mr(factory_dev, aligned_ptr, bytes, &mr->mhandle_);
-    if (ret != 0) {
-        GPU_RT_CHECK(gpuFree(raw_ptr));
-        throw std::runtime_error("MR registration failed");
-    }
-    mr->mr_id_ = next_mr_id_.fetch_add(1);
+void reg_mem(int gpu_id, void* addr, size_t size, uint64_t& mem_id) {
+  if (gpu_id < 0 || gpu_id >= kMaxNumGPUs) {
+    throw std::invalid_argument("[reg_mem] Invalid gpu_id: " +
+                                std::to_string(gpu_id));
+  }
+
+  if (gpu_to_dev[gpu_id] == 0) {
+    throw std::runtime_error(
+        "You must initialize UCCL collective context or Endpoint first");
+  }
+
+  GPU_RT_CHECK(gpuSetDevice(gpu_id));
+
+  uccl::FactoryDevice* factory_dev =
+      uccl::RDMAFactory::get_factory_dev(gpu_to_dev[gpu_id]);
+
+  mem_id = next_mem_id_.fetch_add(1);
+  // MR
+  std::unique_ptr<MR> mr = std::make_unique<MR>();
+  int ret = reg_mr(factory_dev, addr, size, &mr->mhandle_);
+  if (ret != 0) {
+    throw std::runtime_error("MR registration failed");
+  }
+  mr->mr_id_ = mem_id;
+  {
+    std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
+    mr_mapping_[mr->mr_id_] = std::move(mr);
+  }
+
+  // IPC
+  auto addr_aligned = reinterpret_cast<uintptr_t>(addr) & ~(kIpcAlignment - 1);
+  auto addr_offset = reinterpret_cast<uintptr_t>(addr) - addr_aligned;
+  // std::cout << "[reg_mem] Aligned pointer=" << addr_aligned << std::endl;
+
+  std::unique_ptr<IPCMemHandle> ipc = std::make_unique<IPCMemHandle>();
+  ret = get_ipc_handle(reinterpret_cast<void*>(addr_aligned), ipc.get());
+  if (ret != 0) {
    {
-        std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
-        mr_mapping_[mr->mr_id_] = std::move(mr);
+      std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
+      mr_mapping_.erase(mem_id);
    }
-    mr_id = mr->mr_id_;
-    std::cout << "[create_tensor] MR registered, mr_id=" << mr_id << std::endl;
-
-    // IPC
-    std::unique_ptr<IPCMemHandle> ipc = std::make_unique<IPCMemHandle>();
-    ret = get_ipc_handle(aligned_ptr, ipc.get());
-    if (ret != 0) {
-        {
-            std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
-            mr_mapping_.erase(mr_id);
-        }
-        GPU_RT_CHECK(gpuFree(raw_ptr));
-        throw std::runtime_error("IPC handle creation failed");
-    }
-    ipc->id = next_ipc_id_.fetch_add(1);
-    ipc->size = bytes;
-    {
-        std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
-        ipc_handle_mapping_[ipc->id] = std::move(ipc);
-    }
-    ipc_id = ipc->id;
-    std::cout << "[create_tensor] IPC handle created, ipc_id=" << ipc_id
-              << " size=" << bytes << std::endl;
-
-    std::cout << "[create_tensor] SUCCESS: returning tensor with mr_id=" << mr_id
-              << " ipc_id=" << ipc_id << std::endl;
-
-    return tensor;
+    throw std::runtime_error("[reg_mem] IPC handle creation failed");
+  }
+  ipc->size = size;
+  ipc->offset = addr_offset;
+  ipc->id = mem_id;
+  {
+    std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
+    ipc_handle_mapping_[ipc->id] = std::move(ipc);
+  }
}

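A sketch of the new mem_id-keyed lifecycle from the caller's perspective; it assumes the UCCL endpoint (or collective context) has been initialized so `gpu_to_dev` is populated, and that `buf` lives on `gpu_id` (`example_mem_id_lifecycle` is a hypothetical name):

// Illustrative sketch only; not part of the change above.
static void example_mem_id_lifecycle(int gpu_id, void* buf, size_t bytes) {
  uint64_t mem_id = 0;
  try {
    // Registers an MR for [buf, buf + bytes) and exports an IPC handle for
    // the enclosing kIpcAlignment-aligned region; both entries share mem_id.
    reg_mem(gpu_id, buf, bytes, mem_id);
    // ... advertise mem_id (plus the IPC handle and offset) to peers ...
    dereg_mem(mem_id);
  } catch (std::exception const& e) {
    std::cerr << "[example] " << e.what() << "\n";
  }
}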
-
-void free_tensor(torch::Tensor& tensor, uint64_t mr_id, uint64_t ipc_id) {
-    if (tensor.defined()) {
-        tensor.reset();
-    }
-
-    {
-        std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
-        auto it = mr_mapping_.find(mr_id);
-        if (it != mr_mapping_.end()) {
-            dereg_mr(it->second->mhandle_);
-            mr_mapping_.erase(it);
-        } else {
-            std::cerr << "[free_tensor] MR id " << mr_id << " not found!\n";
-        }
+void dereg_mem(uint64_t mem_id) {
+  {
+    std::unique_lock<std::shared_mutex> lock(mr_mapping_mu_);
+    auto it = mr_mapping_.find(mem_id);
+    if (it != mr_mapping_.end()) {
+      dereg_mr(it->second->mhandle_);
+      mr_mapping_.erase(it);
+    } else {
+      std::cerr << "[dereg_mem] MR id " << mem_id << " not found!\n";
    }
-
-    {
-        std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
-        auto it = ipc_handle_mapping_.find(ipc_id);
-        if (it != ipc_handle_mapping_.end()) {
-            ipc_handle_mapping_.erase(it);
-        } else {
-            std::cerr << "[free_tensor] IPC id " << ipc_id << " not found!\n";
-        }
+  }
+  {
+    std::unique_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
+    auto it = ipc_handle_mapping_.find(mem_id);
+    if (it != ipc_handle_mapping_.end()) {
+      ipc_handle_mapping_.erase(it);
+    } else {
+      std::cerr << "[dereg_mem] IPC id " << mem_id << " not found!\n";
    }
+  }
}

-ibv_mr* get_mr_ibv_mr(uint64_t mr_id) {
-    std::shared_lock<std::shared_mutex> lock(mr_mapping_mu_);
-    auto it = mr_mapping_.find(mr_id);
-    if (it != mr_mapping_.end()) {
-        return it->second->mhandle_->mr;
-    }
-    return nullptr;
+ibv_mr* get_ibv_mr_by_mem_id(uint64_t mem_id) {
+  std::shared_lock<std::shared_mutex> lock(mr_mapping_mu_);
+  auto it = mr_mapping_.find(mem_id);
+  if (it != mr_mapping_.end()) {
+    return it->second->mhandle_->mr;
+  }
+  return nullptr;
}

-gpuIpcMemHandle_t get_ipc_mem_handle(uint64_t ipc_id) {
-    std::shared_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
-    auto it = ipc_handle_mapping_.find(ipc_id);
-    if (it != ipc_handle_mapping_.end()) {
-        return it->second->handle;
-    }
-    gpuIpcMemHandle_t handle = {};
-    return handle;
+gpuIpcMemHandle_t get_ipc_mem_handle_by_mem_id(uint64_t mem_id) {
+  std::shared_lock<std::shared_mutex> lock(ipc_handle_mapping_mu_);
+  auto it = ipc_handle_mapping_.find(mem_id);
+  if (it != ipc_handle_mapping_.end()) {
+    return it->second->handle;
+  }
+  gpuIpcMemHandle_t handle = {};
+  return handle;
}
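Both accessors resolve a registered buffer by the same `mem_id`; a short sketch of the consuming side, noting the sentinel values returned for an unknown id (`example_mem_id_lookup` is a hypothetical name):

// Illustrative sketch only; not part of the change above.
static void example_mem_id_lookup(uint64_t mem_id) {
  if (ibv_mr* mr = get_ibv_mr_by_mem_id(mem_id)) {
    (void)mr;  // mr->lkey for local SGEs, mr->rkey to advertise remotely.
  } else {
    std::cerr << "[example] unknown mem_id " << mem_id << "\n";
  }
  // A value-initialized handle is returned when the id is unknown, so callers
  // should gate on the MR lookup (or otherwise track validity) before use.
  gpuIpcMemHandle_t handle = get_ipc_mem_handle_by_mem_id(mem_id);
  (void)handle;
}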