@@ -71,13 +71,20 @@ struct NvlsAdapter {
7171 mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsOutChannels, size_t channelInOffset,
7272 size_t channelOutOffset, size_t , int rank, int nRanksPerNode, int , size_t nelems,
7373 cudaStream_t stream, uint32_t *, uint32_t *, uint32_t *, uint32_t ) {
74- using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
75- int nBlocks = nRanksPerNode;
76- int nThreadsPerBlock = 1024 ;
77- allreduce9<T><<<nBlocks, nThreadsPerBlock, 0 , stream>>> ((ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels,
78- channelInOffset, channelOutOffset, nelems * sizeof (T), rank,
79- nRanksPerNode);
80- return cudaGetLastError ();
74+ #if defined(__CUDA_ARCH__) // Skip FP8 instantiation (including __CUDA_ARCH__ < 1000), since NVLS does not yet support FP8
75+ if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
76+ return cudaErrorNotSupported;
77+ } else
78+ #endif
79+ {
80+ using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
81+ int nBlocks = nRanksPerNode;
82+ int nThreadsPerBlock = 1024 ;
83+ allreduce9<T><<<nBlocks, nThreadsPerBlock, 0 , stream>>> ((ChannelType*)memoryChannels, nvlsChannels,
84+ nvlsOutChannels, channelInOffset, channelOutOffset,
85+ nelems * sizeof (T), rank, nRanksPerNode);
86+ return cudaGetLastError ();
87+ }
8188 }
8289};
8390
@@ -88,21 +95,28 @@ struct NvlsWithCopyAdapter {
8895 mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t , size_t , size_t scratchBufferSize,
8996 int rank, int nRanksPerNode, int , size_t nelems, cudaStream_t stream, uint32_t *, uint32_t *,
9097 uint32_t *, uint32_t ) {
91- using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
92- if (sizeof (T) * nelems < (1 << 24 )) {
93- int nBlocks = nRanksPerNode * 4 ;
94- int nThreadsPerBlock = 1024 ;
95- allreduce10<T><<<nBlocks, nThreadsPerBlock, 0 , stream>>> (input, scratch, output, (ChannelType*)memoryChannels,
96- nvlsChannels, nelems * sizeof (T), scratchBufferSize,
97- rank, nRanksPerNode);
98- } else {
99- int nBlocks = nRanksPerNode * 5 ;
100- int nThreadsPerBlock = 1024 ;
101- allreduce11<T><<<nBlocks, nThreadsPerBlock, 0 , stream>>> (input, scratch, output, (ChannelType*)memoryChannels,
102- nvlsChannels, nelems * sizeof (T), scratchBufferSize,
103- rank, nRanksPerNode);
98+ #if defined(__CUDA_ARCH__) // Skip FP8 instantiation (including __CUDA_ARCH__ < 1000), since NVLS does not yet support FP8
99+ if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
100+ return cudaErrorNotSupported;
101+ } else
102+ #endif
103+ {
104+ using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
105+ if (sizeof (T) * nelems < (1 << 24 )) {
106+ int nBlocks = nRanksPerNode * 4 ;
107+ int nThreadsPerBlock = 1024 ;
108+ allreduce10<T><<<nBlocks, nThreadsPerBlock, 0 , stream>>> (input, scratch, output, (ChannelType*)memoryChannels,
109+ nvlsChannels, nelems * sizeof (T), scratchBufferSize,
110+ rank, nRanksPerNode);
111+ } else {
112+ int nBlocks = nRanksPerNode * 5 ;
113+ int nThreadsPerBlock = 1024 ;
114+ allreduce11<T><<<nBlocks, nThreadsPerBlock, 0 , stream>>> (input, scratch, output, (ChannelType*)memoryChannels,
115+ nvlsChannels, nelems * sizeof (T), scratchBufferSize,
116+ rank, nRanksPerNode);
117+ }
118+ return cudaGetLastError ();
104119 }
105- return cudaGetLastError ();
106120 }
107121};
108122
@@ -154,6 +168,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
154168#if defined(__CUDA_BF16_TYPES_EXIST__)
155169 } else if (dtype == ncclBfloat16) {
156170 return Adapter<SUM, __bfloat16>::call;
171+ #endif
172+ #if defined(__FP8_TYPES_EXIST__)
173+ } else if (dtype == ncclFloat8e4m3) {
174+ return Adapter<SUM, __fp8_e4m3>::call;
175+ } else if (dtype == ncclFloat8e5m2) {
176+ return Adapter<SUM, __fp8_e5m2>::call;
157177#endif
158178 } else if (dtype == ncclInt32 || dtype == ncclUint32) {
159179 return Adapter<SUM, int >::call;
@@ -168,6 +188,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
168188#if defined(__CUDA_BF16_TYPES_EXIST__)
169189 } else if (dtype == ncclBfloat16) {
170190 return Adapter<MIN, __bfloat16>::call;
191+ #endif
192+ #if defined(__FP8_TYPES_EXIST__)
193+ } else if (dtype == ncclFloat8e4m3) {
194+ return Adapter<MIN, __fp8_e4m3>::call;
195+ } else if (dtype == ncclFloat8e5m2) {
196+ return Adapter<MIN, __fp8_e5m2>::call;
171197#endif
172198 } else if (dtype == ncclInt32 || dtype == ncclUint32) {
173199 return Adapter<MIN, int >::call;
0 commit comments