92 changes: 92 additions & 0 deletions apps/nccl/src/allreduce.cu
@@ -123,6 +123,26 @@ struct Allreduce8Adapter {
}
};

template <Op OpType, typename T>
struct AllreduceNvlsPacketAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
int rank, int, int worldSize, size_t nelems, cudaStream_t stream, uint32_t* deviceFlag,
uint32_t*, uint32_t*, uint32_t) {
size_t size = nelems * sizeof(T);
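// Launch heuristic: messages of 8 KiB or less use 4 blocks x 512 threads; larger messages use 8 blocks x 1024 threads.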
int nBlocks = 8;
int nThreadsPerBlock = 1024;
if (size <= (1 << 13)) {
nBlocks = 4;
nThreadsPerBlock = 512;
}
allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(const T*)input, (T*)scratch, (T*)output, nvlsChannels, nelems, scratchBufferSize, rank, worldSize, deviceFlag);
return cudaGetLastError();
}
};

template <template <Op, typename> class Adapter>
AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
Op reduceOp = getReduceOp(op);
@@ -558,4 +578,76 @@ mscclpp::Algorithm Allreduce8::build() {
return self->generateAllreduceContextKey(input, output, count, static_cast<ncclDataType_t>(dtype));
});
return allreduceAlgo;
}

void AllreduceNvlsPacket::initialize(std::shared_ptr<mscclpp::Communicator>,
std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
this->scratchBuffer_ = std::static_pointer_cast<char>(extras.at("scratch"));
this->scratchBufferSize_ = *(size_t*)(extras.at("scratch_size").get());
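// One flag per thread block (16 slots cover the largest launch); flags start at 1 and are incremented by the kernel after every call.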
deviceFlag_ = mscclpp::detail::gpuCallocShared<uint32_t>(16);
std::vector<uint32_t> initFlag(16);
for (int i = 0; i < 16; ++i) {
initFlag[i] = 1;
}
mscclpp::gpuMemcpy<uint32_t>(deviceFlag_.get(), initFlag.data(), 16, cudaMemcpyHostToDevice);
}

mscclpp::AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, ncclDataType_t) {
return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
}

std::shared_ptr<mscclpp::AlgorithmCtx> AllreduceNvlsPacket::initAllreduceContext(
std::shared_ptr<mscclpp::Communicator> comm, const void*, void*, size_t, ncclDataType_t) {
auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

// setup channels
int nSwitchChannels = 1;
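// Bind the shared scratch buffer to a single NVLS switch channel; the kernel reaches every rank's scratch through the channel's multicast pointer.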
ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
ctx->switchChannels =
setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_.get(), this->scratchBufferSize_, nSwitchChannels);
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
return ctx;
}

ncclResult_t AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx,
const void* input, void* output, size_t count,
ncclDataType_t dtype, cudaStream_t stream,
std::unordered_map<std::string, std::shared_ptr<void>>& extra) {
int op = *static_cast<int*>(extra.at("op").get());
AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(static_cast<ncclRedOp_t>(op), dtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, op=%d, dtype=%d", op, dtype);
return ncclInvalidArgument;
}
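// Arguments used only by other adapters (memory channels, offsets, copy flags) are passed as null/zero to match the generic AllreduceFunc signature.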
cudaError_t error =
allreduce(input, this->scratchBuffer_.get(), output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(),
nullptr, 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, count, stream,
this->deviceFlag_.get(), nullptr, nullptr, 0);
if (error != cudaSuccess) {
WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
return ncclUnhandledCudaError;
}
return ncclSuccess;
}

mscclpp::Algorithm AllreduceNvlsPacket::build() {
auto self = std::make_shared<AllreduceNvlsPacket>();
mscclpp::Algorithm allreduceAlgo(
"default_allreduce_nvls_packet", "allreduce",
[self](std::shared_ptr<mscclpp::Communicator> comm,
std::unordered_map<std::string, std::shared_ptr<void>>& extras) { self->initialize(comm, extras); },
[self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t count, int dtype,
cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
return self->allreduceKernelFunc(ctx, input, output, count, static_cast<ncclDataType_t>(dtype), stream, extras);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count, int dtype) {
return self->initAllreduceContext(comm, input, output, count, static_cast<ncclDataType_t>(dtype));
},
[self](const void* input, void* output, size_t count, int dtype) {
return self->generateAllreduceContextKey(input, output, count, static_cast<ncclDataType_t>(dtype));
});
return allreduceAlgo;
}
63 changes: 63 additions & 0 deletions apps/nccl/src/allreduce.hpp
@@ -793,6 +793,45 @@ __global__ void __launch_bounds__(1024, 1)
#endif
}

template <Op OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output,
[[maybe_unused]] mscclpp::DeviceHandle<mscclpp::SwitchChannel>* multicast,
[[maybe_unused]] size_t nelems, [[maybe_unused]] size_t scratchBufferSize,
[[maybe_unused]] int rank, [[maybe_unused]] int worldSize,
[[maybe_unused]] uint32_t* deviceFlag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t flag = deviceFlag[blockIdx.x];
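// Flag parity selects which half of the scratch buffer this call uses (double buffering across consecutive calls).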
size_t scratchBaseOffset = (flag % 2) ? scratchBufferSize / 2 : 0;
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
uint32_t nPktPerRank = nelems / worldSize / (sizeof(mscclpp::LLPacket::Payload) / sizeof(T));
mscclpp::LLPacket* multiPkt =
(mscclpp::LLPacket*)((char*)multicast->mcPtr + scratchBaseOffset) + rank * worldSize * nPktPerRank;
uint2* src2 = (uint2*)(input);
uint2* dst2 = (uint2*)output;
mscclpp::LLPacket* scratchPkt = (mscclpp::LLPacket*)((char*)scratch + scratchBaseOffset);
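// Phase 1: pack the local input into LL packets and multimem-store them into this rank's segment of the multicast scratch buffer, which writes to every rank's copy.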
for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
mscclpp::LLPacket pkt(src2[i], flag);
mscclpp::SwitchChannelDeviceHandle::multimemStore(pkt, multiPkt + i);
}
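// Phase 2: combine each local value with the corresponding packet from every peer; read() spins until the peer's data tagged with this flag has arrived.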
for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
uint2 data = src2[i];
for (int peer = 0; peer < worldSize; peer++) {
if (peer == rank) {
continue;
}
uint2 val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag);
data.x = cal_vectors<T, OpType>(data.x, val.x);
data.y = cal_vectors<T, OpType>(data.y, val.y);
}
dst2[i] = data;
}
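// One thread per block bumps the flag so the next call flips scratch halves and uses a fresh packet flag.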
if (threadIdx.x == 0) {
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
}
#endif
}

enum Op getReduceOp(ncclRedOp_t op);

class AllreducePacket : public mscclpp::AlgorithmBuilder {
@@ -896,4 +935,28 @@ class Allreduce8 : public mscclpp::AlgorithmBuilder {
memoryChannelsMap_;
};

class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
public:
mscclpp::Algorithm build() override;

private:
void initialize(std::shared_ptr<mscclpp::Communicator> comm,
std::unordered_map<std::string, std::shared_ptr<void>>& extras);
ncclResult_t allreduceKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
size_t count, ncclDataType_t dtype, cudaStream_t stream,
std::unordered_map<std::string, std::shared_ptr<void>>& extras);

std::shared_ptr<mscclpp::AlgorithmCtx> initAllreduceContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
void* output, size_t, ncclDataType_t);
mscclpp::AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, ncclDataType_t);

size_t scratchBufferSize_;
std::shared_ptr<char> scratchBuffer_;
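// The scratch buffer is used as two alternating halves (selected by flag parity); 1 GiB of NVLS multicast space is requested for the switch channel.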
const int nSegmentsForScratchBuffer_ = 2;
const size_t nvlsBufferSize_ = (1 << 30);

std::shared_ptr<uint32_t> deviceFlag_;
std::shared_ptr<mscclpp::AlgorithmCtx> ctx_;
};

#endif // ALLREDUCE_KERNEL_H
29 changes: 18 additions & 11 deletions apps/nccl/src/nccl.cu
@@ -271,10 +271,12 @@ static void registerCustomizedAlgo() {
std::shared_ptr<AllreduceNvls> allreduceNvlsAlgo = std::make_shared<AllreduceNvls>();
std::shared_ptr<AllreduceNvlsWithCopy> allreduceNvlsWithCopyAlgo = std::make_shared<AllreduceNvlsWithCopy>();
std::shared_ptr<Allreduce8> allreduceAllreduce8Algo = std::make_shared<Allreduce8>();
std::shared_ptr<AllreduceNvlsPacket> allreduceNvlsPacketAlgo = std::make_shared<AllreduceNvlsPacket>();
collectionBuilder->addAlgorithmBuilder(allreduceAllpairAlgo);
collectionBuilder->addAlgorithmBuilder(allreduceNvlsAlgo);
collectionBuilder->addAlgorithmBuilder(allreduceNvlsWithCopyAlgo);
collectionBuilder->addAlgorithmBuilder(allreduceAllreduce8Algo);
collectionBuilder->addAlgorithmBuilder(allreduceNvlsPacketAlgo);
}

static mscclpp::Algorithm algoSelector(
@@ -284,10 +286,11 @@ static mscclpp::Algorithm algoSelector(
// Fallback to nccl/rccl when multi-node
return mscclpp::Algorithm();
}
static bool isCuMemMapAllocated =
mscclpp::isCuMemMapAllocated(const_cast<void*>(input)) && mscclpp::isCuMemMapAllocated(output);
static bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache;
static bool isNvlsSupported = mscclpp::isNvlsSupported();
static bool useNvlsWithZeroCopy = isNvlsSupported && !mscclppDisableChannelCache && isCuMemMapAllocated;
if (collective == "allgather") {
if (messageSize <= 32 * (1 << 20)) {
return algoMapByCollective.at(collective).at("default_allgather6");
@@ -302,21 +305,25 @@
}
}
if (collective == "allreduce") {
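// Route small messages (<= 32 KiB) to the NVLS LL-packet kernel when NVLS is available.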
if (messageSize <= (1 << 15) && isNvlsSupported) {
return algoMapByCollective.at(collective).at("default_allreduce_nvls_packet");
}
if (messageSize <= (1 << 16) || (messageSize <= (1 << 20) && !useNvlsWithZeroCopy)) {
return algoMapByCollective.at(collective).at("default_allreduce_packet");
}
if (useNvlsWithZeroCopy) {
return algoMapByCollective.at(collective).at("default_allreduce_nvls");
}
if (mscclpp::isNvlsSupported()) {
return algoMapByCollective.at(collective).at("default_allreduce_nvls_with_copy");
}
#if defined(__HIP_PLATFORM_AMD__)
return algoMapByCollective.at(collective).at("default_allreduce_allreduce8");
#else
if (!mscclppNcclDlopenSharedLib) {
return algoMapByCollective.at(collective).at("default_allreduce_allreduce8");
}
#endif
}
INFO(MSCCLPP_NCCL, "Failed to get algo from customized kernel, fallback to nccl/rccl");
return mscclpp::Algorithm();
13 changes: 13 additions & 0 deletions include/mscclpp/switch_channel_device.hpp
@@ -14,6 +14,7 @@
#include <mscclpp/gpu_data_types.hpp>

#include "device.hpp"
#include "packet_device.hpp"

namespace mscclpp {

@@ -37,6 +38,11 @@ struct SwitchChannelDeviceHandle {
SwitchChannelDeviceHandle::multimemStore(val, reinterpret_cast<T*>(mcPtr) + index);
}

MSCCLPP_DEVICE_INLINE void broadcast(uint64_t index, const LLPacket& val) {
multimemStore(val, reinterpret_cast<LLPacket*>(mcPtr) + index);
}

template <typename VectorType>
MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) {
VectorType val;
@@ -129,6 +135,13 @@
}
};

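// Store an entire 16-byte LL packet (two payload/flag pairs) to multicast memory with a single vectorized multimem store.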
template <typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const LLPacket& val, T* ptr) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.data1),
"r"(val.flag1), "r"(val.data2), "r"(val.flag2)
: "memory");
}

template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
if constexpr (std::is_same_v<TValue, float4> && std::is_same_v<T, float>) {