apps/nccl/src/allgather.hpp: 7 additions, 0 deletions
@@ -102,6 +102,13 @@ __global__ void __launch_bounds__(1024, 1)
       }
     }
   }
+
+  deviceSyncer.sync(gridDim.x);
+
+  if (threadIdx.x < nPeer) {
+    smChans[threadIdx.x].relaxedSignal();
+    smChans[threadIdx.x].wait();
+  }
 }
 
 template <bool IsOutOfPlace, typename T>
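Note on the hunk above: it adds a two-stage exit barrier to the allgather kernel. `deviceSyncer.sync(gridDim.x)` waits until every thread block has finished its writes, and the first `nPeer` threads then run a signal/wait handshake with each peer so the kernel cannot return while a peer is still working on shared buffers. A minimal sketch of the pattern, assuming mscclpp's `DeviceSyncer` and `SmChannelDeviceHandle`; the kernel name, channel array, and omitted body are illustrative, not from the diff:

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/sm_channel_device.hpp>

// One grid-wide syncer instance, as common.hpp now provides.
__device__ mscclpp::DeviceSyncer exampleSyncer;

__global__ void exampleKernel(mscclpp::SmChannelDeviceHandle* chans, int nPeer) {
  // ... each block writes its slice of the result buffer ...

  exampleSyncer.sync(gridDim.x);  // barrier across all thread blocks of this kernel

  if (threadIdx.x < nPeer) {
    chans[threadIdx.x].relaxedSignal();  // tell peer threadIdx.x our writes are done
    chans[threadIdx.x].wait();           // block until that peer says the same
  }
}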
apps/nccl/src/allreduce.hpp: 5 additions, 4 deletions
@@ -14,8 +14,6 @@
#include "common.hpp"
#include "gpu_data_types.hpp"

__device__ mscclpp::DeviceSyncer deviceSyncer;

template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
@@ -236,14 +234,13 @@ __global__ void __launch_bounds__(1024, 1)
                            blockDim.x * nBlocksPerPeer, flag);
   // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
   for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
-    uint32_t data = 0;
+    uint32_t data = src[idx];
     for (int index = 0; index < nPeers; index++) {
       const int remoteRank = index < rank ? index : index + 1;
       mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank;
       uint32_t val = dstPkt[idx].read(flag, -1);
       data = add_vectors<T>(val, data);
     }
-    data = add_vectors<T>(data, src[idx]);
     dst[idx] = data;
 
     mscclpp::LL8Packet packet;
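Note on the hunk above: the accumulator is now seeded with the local value `src[idx]` instead of starting at zero and folding `src[idx]` in after the peer loop. The result is the same because the add is commutative, and it saves one `add_vectors` call per packet. A host-side sketch of the equivalence, where plain `+` stands in for `add_vectors<T>` and all names are illustrative:

#include <cstdint>
#include <vector>

// Old shape: start at zero, fold the local value in after the loop.
uint32_t reduceOld(const std::vector<uint32_t>& peerVals, uint32_t local) {
  uint32_t data = 0;
  for (uint32_t v : peerVals) data = v + data;  // one add per peer
  return data + local;                          // plus one extra add at the end
}

// New shape: seed the accumulator with the local value.
uint32_t reduceNew(const std::vector<uint32_t>& peerVals, uint32_t local) {
  uint32_t data = local;
  for (uint32_t v : peerVals) data = v + data;  // one add per peer, nothing after
  return data;
}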
@@ -384,6 +381,10 @@ __global__ void __launch_bounds__(512, 1)
       }
     }
   }
+  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+    outChannels[threadIdx.x].signal();
+    outChannels[threadIdx.x].wait();
+  }
 }
 
 template <typename T>
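The second hunk gives the 512-thread allreduce kernel the same kind of exit handshake, but via `signal()` rather than `relaxedSignal()`: per mscclpp's semaphore semantics, `signal()` also orders prior memory operations before the notification, while `relaxedSignal()` omits that guarantee and relies on ordering being established elsewhere. A hypothetical helper equivalent to the added lines (the function name is illustrative):

#include <mscclpp/sm_channel_device.hpp>

// Each of the first nPeer threads performs a one-to-one handshake with its
// peer; no rank can leave the kernel (and let the host reuse its buffers)
// before every peer has both sent and received a signal.
__forceinline__ __device__ void exitHandshake(mscclpp::SmChannelDeviceHandle* chans, int nPeer) {
  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
    chans[threadIdx.x].signal();  // signal() also makes our prior writes visible
    chans[threadIdx.x].wait();    // spin until the peer's matching signal arrives
  }
}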
apps/nccl/src/common.hpp: 4 additions, 0 deletions
@@ -4,6 +4,8 @@
 #ifndef NCCL_COMMON_HPP_
 #define NCCL_COMMON_HPP_
 
+#include <mscclpp/concurrency_device.hpp>
+
 #if defined(__HIP_PLATFORM_AMD__)
 #define WARP_SIZE 64
 #define __syncwarp() __builtin_amdgcn_wave_barrier()
@@ -14,4 +14,6 @@
 constexpr int NRANKS_PER_NODE = 8;
 constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
 
+__device__ mscclpp::DeviceSyncer deviceSyncer;
+
 #endif // NCCL_COMMON_HPP_
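Moving the `deviceSyncer` definition out of allreduce.hpp and into common.hpp lets the allgather and allreduce kernels share one syncer instance. Since it is a plain (non-`extern`, non-`inline`) `__device__` variable defined in a header, it stays a single instance only while these headers are reached from a single translation unit, which appears to hold here with nccl.cu as the lone .cu file. A sketch of the resulting layout (comments illustrative):

// common.hpp (as patched): one syncer definition for the whole app.
#include <mscclpp/concurrency_device.hpp>
__device__ mscclpp::DeviceSyncer deviceSyncer;

// allgather.hpp and allreduce.hpp both reference the same instance:
//   deviceSyncer.sync(gridDim.x);  // grid-wide barrier inside a kernel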
apps/nccl/src/nccl.cu: 5 additions, 4 deletions
@@ -293,9 +293,10 @@ NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) {
   return ncclSuccess;
 }
 
-NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t*, int, ncclUniqueId, int, ncclConfig_t*) {
-  // TODO: implement this function
-  return ncclInternalError;
+NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank,
+                                             ncclConfig_t*) {
+  // TODO: implement config
+  return ncclCommInitRank(comm, nranks, commId, rank);
 }
 
 NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
@@ -419,7 +420,7 @@ NCCL_API const char* ncclGetErrorString(ncclResult_t result) {

 NCCL_API const char* ncclGetLastError(ncclComm_t) {
   // TODO: implement this function
-  return nullptr;
+  return "";
 }
 
 NCCL_API ncclResult_t ncclCommGetAsyncError(ncclComm_t, ncclResult_t* asyncError) {
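Two behavioral fixes here: `ncclCommInitRankConfig` now degrades gracefully by delegating to `ncclCommInitRank` (the config argument is accepted but ignored) instead of failing with `ncclInternalError`, and the `ncclGetLastError` stub returns an empty string rather than `nullptr`, since callers routinely feed the result straight into `printf`-style formatting, where a null pointer is undefined behavior and "" simply prints nothing. A usage sketch, assuming the bundled nccl.h mirrors upstream NCCL's `NCCL_CONFIG_INITIALIZER`; the function and variables are illustrative:

#include <cstdio>
#include "nccl.h"

void initExample(int nranks, int rank, ncclUniqueId id) {
  ncclComm_t comm = nullptr;
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;  // fields are currently ignored
  // Previously this call always returned ncclInternalError; it now behaves
  // like ncclCommInitRank.
  if (ncclCommInitRankConfig(&comm, nranks, id, rank, &config) != ncclSuccess) {
    // Safe even though ncclGetLastError is a stub: "" prints as nothing,
    // whereas the old nullptr return made this printf undefined behavior.
    std::printf("init failed: %s\n", ncclGetLastError(comm));
  }
}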