diff --git a/apps/nccl/src/allreduce.cu b/apps/nccl/src/allreduce.cu index 0a48db61b..bac933a01 100644 --- a/apps/nccl/src/allreduce.cu +++ b/apps/nccl/src/allreduce.cu @@ -123,6 +123,26 @@ struct Allreduce8Adapter { } }; +template +struct AllreduceNvlsPacketAdapter { + static cudaError_t call(const void* input, void* scratch, void* output, void*, void*, + mscclpp::DeviceHandle* nvlsChannels, + mscclpp::DeviceHandle*, size_t, size_t, size_t scratchBufferSize, + int rank, int, int worldSize, size_t nelems, cudaStream_t stream, uint32_t* deviceFlag, + uint32_t*, uint32_t*, uint32_t) { + size_t size = nelems * sizeof(T); + int nBlocks = 8; + int nThreadsPerBlock = 1024; + if (size <= (1 << 13)) { + nBlocks = 4; + nThreadsPerBlock = 512; + } + allreduceNvlsPacket<<>>( + (const T*)input, (T*)scratch, (T*)output, nvlsChannels, nelems, scratchBufferSize, rank, worldSize, deviceFlag); + return cudaGetLastError(); + } +}; + template