|
12 | 12 | #include <memory> |
13 | 13 | #include <mscclpp/env.hpp> |
14 | 14 | #include <mscclpp/errors.hpp> |
| 15 | +#include <mscclpp/gpu_utils.hpp> |
15 | 16 | #include <sstream> |
16 | 17 | #include <string> |
17 | 18 |
|
@@ -232,4 +233,42 @@ void getRandomData(void* buffer, size_t bytes) { |
232 | 233 | } |
233 | 234 | } |
234 | 235 |
|
| 236 | +TokenPool::TokenPool(size_t nToken) : nToken_(nToken) { |
| 237 | +#if (CUDA_NVLS_API_AVAILABLE) |
| 238 | + tokens_ = detail::gpuCallocPhysicalShared<uint64_t>( |
| 239 | + nToken, detail::getCuAllocationGranularity(CU_MEM_ALLOC_GRANULARITY_MINIMUM)); |
| 240 | + MSCCLPP_CUTHROW(cuMemGetAddressRange((CUdeviceptr*)(&baseAddr_), NULL, (CUdeviceptr)tokens_.get())); |
| 241 | + size_t nElems = (nToken + (UINT64_WIDTH - 1)) / UINT64_WIDTH; |
| 242 | + allocationMap_.resize(nElems, 0); |
| 243 | + tailMask_ = (nToken % UINT64_WIDTH) ? ((1UL << (nToken % UINT64_WIDTH)) - 1) : ~0UL; |
| 244 | +#else |
| 245 | + throw Error("TokenPool only available on GPUs with NVLS support", ErrorCode::InvalidUsage); |
| 246 | +#endif |
| 247 | +} |
| 248 | + |
| 249 | +std::shared_ptr<uint64_t> TokenPool::getToken() { |
| 250 | + auto deleter = [self = shared_from_this()](uint64_t* token) { |
| 251 | + size_t index = (token - self->baseAddr_) / UINT64_WIDTH; |
| 252 | + size_t bit = (token - self->baseAddr_) % UINT64_WIDTH; |
| 253 | + uint64_t mask = 1UL << bit; |
| 254 | + self->allocationMap_[index] &= ~mask; |
| 255 | + }; |
| 256 | + |
| 257 | + size_t size = allocationMap_.size(); |
| 258 | + for (size_t i = 0; i < size; i++) { |
| 259 | + uint64_t ullong = allocationMap_[i].to_ullong(); |
| 260 | + uint64_t mask = (i + 1 == size) ? tailMask_ : ~0ULL; |
| 261 | + uint64_t holes = (~ullong) & mask; |
| 262 | + if (!holes) continue; |
| 263 | + for (int bit = 0; bit < UINT64_WIDTH; bit++) { |
| 264 | + if (holes & (1UL << bit)) { |
| 265 | + allocationMap_[i].set(bit); |
| 266 | + INFO(MSCCLPP_ALLOC, "TokenPool allocated token at addr %p", baseAddr_ + i * UINT64_WIDTH + bit); |
| 267 | + return std::shared_ptr<uint64_t>(baseAddr_ + i * UINT64_WIDTH + bit, deleter); |
| 268 | + } |
| 269 | + } |
| 270 | + } |
| 271 | + throw Error("TokenPool is exhausted", ErrorCode::InternalError); |
| 272 | +} |
| 273 | + |
235 | 274 | } // namespace mscclpp |
0 commit comments