Commit a38c2ee

FP8 support for Allreduce (#646)
Add FP8 support for Allreduce on both the NVIDIA and AMD platforms, and add the new data types fp8_e4m3 and fp8_e5m2.

Co-authored-by: Binyang Li <[email protected]>
1 parent fc0aaaf commit a38c2ee

11 files changed (+680, −59 lines)

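Before the per-file diffs, here is a hypothetical caller-side sketch of what this change enables, assuming the NCCL-style ncclAllReduce entry point exposed by this shim and the include path "mscclpp/nccl.h"; the function, buffer names, and setup are illustrative and not part of the commit.

```cpp
#include <cstddef>
#include <cuda_runtime.h>
#include "mscclpp/nccl.h"  // include path assumed from apps/nccl/include/

// Hypothetical example: sum-allreduce a buffer of FP8 e4m3 values using the
// new ncclFloat8e4m3 data type. `comm` and `stream` are assumed to have been
// created with the usual ncclCommInitRank / cudaStreamCreate setup.
void allreduceFp8Example(ncclComm_t comm, cudaStream_t stream, size_t count) {
  void* sendbuff = nullptr;
  void* recvbuff = nullptr;
  cudaMalloc(&sendbuff, count * ncclTypeSize(ncclFloat8e4m3));  // 1 byte per element
  cudaMalloc(&recvbuff, count * ncclTypeSize(ncclFloat8e4m3));
  ncclAllReduce(sendbuff, recvbuff, count, ncclFloat8e4m3, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);
  cudaFree(sendbuff);
  cudaFree(recvbuff);
}
```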

apps/nccl/include/mscclpp/nccl.h

Lines changed: 4 additions & 15 deletions
```diff
@@ -248,17 +248,10 @@ typedef enum {
   ncclFloat = 7,
   ncclFloat64 = 8,
   ncclDouble = 8,
-#if defined(__CUDA_BF16_TYPES_EXIST__) && defined(__CUDA_FP8_TYPES_EXIST__)
   ncclBfloat16 = 9,
-  ncclFp8E4M3 = 10,
-  ncclFp8E5M2 = 11,
+  ncclFloat8e4m3 = 10,
+  ncclFloat8e5m2 = 11,
   ncclNumTypes = 12
-#elif defined(__CUDA_BF16_TYPES_EXIST__)
-  ncclBfloat16 = 9,
-  ncclNumTypes = 10
-#else
-  ncclNumTypes = 9
-#endif
 } ncclDataType_t;
 
 static inline size_t ncclTypeSize(ncclDataType_t type) {
@@ -278,15 +271,11 @@ static inline size_t ncclTypeSize(ncclDataType_t type) {
       return 4;
     case ncclFloat64:
       return 8;
-#if defined(__CUDA_BF16_TYPES_EXIST__)
     case ncclBfloat16:
       return 2;
-#endif  // defined(__CUDA_BF16_TYPES_EXIST__)
-#if defined(__CUDA_FP8_TYPES_EXIST__)
-    case ncclFp8E4M3:
-    case ncclFp8E5M2:
+    case ncclFloat8e4m3:
+    case ncclFloat8e5m2:
       return 1;
-#endif  // defined(__CUDA_FP8_TYPES_EXIST__)
     case ncclNumTypes:
       return 0;
   }
```
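Because the FP8 enum members above are no longer wrapped in __CUDA_FP8_TYPES_EXIST__, their values and ncclNumTypes are the same in every build. A small illustrative check (not from the commit; the include path is assumed):

```cpp
#include <cstddef>
#include "mscclpp/nccl.h"  // path assumed from apps/nccl/include/

// These follow directly from the unconditional enum definition above.
static_assert(ncclFloat8e4m3 == 10 && ncclFloat8e5m2 == 11, "FP8 enum values");
static_assert(ncclNumTypes == 12, "type count no longer depends on CUDA type macros");

// Both FP8 formats occupy 1 byte per element, so sizing a buffer reduces to:
inline size_t fp8BufferBytes(size_t nelems) {
  return nelems * ncclTypeSize(ncclFloat8e4m3);  // equals nelems
}
```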

apps/nccl/src/allreduce.cu

Lines changed: 47 additions & 21 deletions
```diff
@@ -71,13 +71,20 @@ struct NvlsAdapter {
                    mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsOutChannels, size_t channelInOffset,
                    size_t channelOutOffset, size_t, int rank, int nRanksPerNode, int, size_t nelems,
                    cudaStream_t stream, uint32_t*, uint32_t*, uint32_t*, uint32_t) {
-    using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
-    int nBlocks = nRanksPerNode;
-    int nThreadsPerBlock = 1024;
-    allreduce9<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels,
-                                                            channelInOffset, channelOutOffset, nelems * sizeof(T), rank,
-                                                            nRanksPerNode);
-    return cudaGetLastError();
+#if defined(__CUDA_ARCH__)  // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
+    if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+      return cudaErrorNotSupported;
+    } else
+#endif
+    {
+      using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
+      int nBlocks = nRanksPerNode;
+      int nThreadsPerBlock = 1024;
+      allreduce9<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((ChannelType*)memoryChannels, nvlsChannels,
+                                                              nvlsOutChannels, channelInOffset, channelOutOffset,
+                                                              nelems * sizeof(T), rank, nRanksPerNode);
+      return cudaGetLastError();
+    }
   }
 };
```
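The new guard pairs an `if constexpr` type check with an `#if defined(__CUDA_ARCH__)` wrapper so that FP8 element types fall back to cudaErrorNotSupported rather than instantiating the NVLS kernel launch. Below is a minimal standalone sketch of the `if constexpr` part of that pattern; the kernel and helper names are hypothetical, the CUDA built-in FP8 types are used in place of the project's __fp8_e4m3/__fp8_e5m2 aliases, and the __CUDA_ARCH__ wrapper is omitted for brevity.

```cpp
#include <cuda_fp8.h>      // defines __nv_fp8_e4m3 / __nv_fp8_e5m2 (CUDA 11.8+)
#include <cuda_runtime.h>
#include <type_traits>

// Stand-in kernel; the commit launches allreduce9 here instead.
template <typename T>
__global__ void dummyReduce(T* buf, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = buf[i];  // placeholder work
}

// Same shape as NvlsAdapter::call above: FP8 element types take the
// compile-time branch and report "not supported" without ever
// instantiating the kernel launch for those types.
template <typename T>
cudaError_t launchIfSupported(T* buf, size_t n, cudaStream_t stream) {
  if constexpr (std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>) {
    return cudaErrorNotSupported;
  } else {
    unsigned int nBlocks = static_cast<unsigned int>((n + 1023) / 1024);
    dummyReduce<T><<<nBlocks, 1024, 0, stream>>>(buf, n);
    return cudaGetLastError();
  }
}
```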

```diff
@@ -88,21 +95,28 @@ struct NvlsWithCopyAdapter {
                    mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
                    int rank, int nRanksPerNode, int, size_t nelems, cudaStream_t stream, uint32_t*, uint32_t*,
                    uint32_t*, uint32_t) {
-    using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
-    if (sizeof(T) * nelems < (1 << 24)) {
-      int nBlocks = nRanksPerNode * 4;
-      int nThreadsPerBlock = 1024;
-      allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
-                                                               nvlsChannels, nelems * sizeof(T), scratchBufferSize,
-                                                               rank, nRanksPerNode);
-    } else {
-      int nBlocks = nRanksPerNode * 5;
-      int nThreadsPerBlock = 1024;
-      allreduce11<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
-                                                               nvlsChannels, nelems * sizeof(T), scratchBufferSize,
-                                                               rank, nRanksPerNode);
+#if defined(__CUDA_ARCH__)  // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
+    if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+      return cudaErrorNotSupported;
+    } else
+#endif
+    {
+      using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
+      if (sizeof(T) * nelems < (1 << 24)) {
+        int nBlocks = nRanksPerNode * 4;
+        int nThreadsPerBlock = 1024;
+        allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
+                                                                 nvlsChannels, nelems * sizeof(T), scratchBufferSize,
+                                                                 rank, nRanksPerNode);
+      } else {
+        int nBlocks = nRanksPerNode * 5;
+        int nThreadsPerBlock = 1024;
+        allreduce11<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
+                                                                 nvlsChannels, nelems * sizeof(T), scratchBufferSize,
+                                                                 rank, nRanksPerNode);
+      }
+      return cudaGetLastError();
     }
-    return cudaGetLastError();
   }
 };
```

```diff
@@ -154,6 +168,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
 #if defined(__CUDA_BF16_TYPES_EXIST__)
     } else if (dtype == ncclBfloat16) {
       return Adapter<SUM, __bfloat16>::call;
+#endif
+#if defined(__FP8_TYPES_EXIST__)
+    } else if (dtype == ncclFloat8e4m3) {
+      return Adapter<SUM, __fp8_e4m3>::call;
+    } else if (dtype == ncclFloat8e5m2) {
+      return Adapter<SUM, __fp8_e5m2>::call;
 #endif
     } else if (dtype == ncclInt32 || dtype == ncclUint32) {
       return Adapter<SUM, int>::call;
@@ -168,6 +188,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
 #if defined(__CUDA_BF16_TYPES_EXIST__)
     } else if (dtype == ncclBfloat16) {
       return Adapter<MIN, __bfloat16>::call;
+#endif
+#if defined(__FP8_TYPES_EXIST__)
+    } else if (dtype == ncclFloat8e4m3) {
+      return Adapter<MIN, __fp8_e4m3>::call;
+    } else if (dtype == ncclFloat8e5m2) {
+      return Adapter<MIN, __fp8_e5m2>::call;
 #endif
     } else if (dtype == ncclInt32 || dtype == ncclUint32) {
       return Adapter<MIN, int>::call;
```
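The dispatch function maps (op, dtype) to a concrete launch adapter, and the new FP8 branches are compiled only when the FP8 types exist (the commit guards them with __FP8_TYPES_EXIST__). Below is a minimal, self-contained sketch of that ladder shape; AllreduceFunc, Adapter, and the Fp8 stand-in types are hypothetical, and only the enum values (7, 10, 11) come from nccl.h above.

```cpp
#include <cstddef>
#include <cstdint>

using AllreduceFunc = int (*)(const void* input, void* output, size_t nelems);

struct Fp8E4M3 { uint8_t bits; };  // stand-in for __fp8_e4m3
struct Fp8E5M2 { uint8_t bits; };  // stand-in for __fp8_e5m2

// Placeholder adapter; the real Adapter<OP, T>::call launches a CUDA kernel.
template <typename T>
struct Adapter {
  static int call(const void*, void*, size_t) { return 0; }
};

// Same ladder shape as dispatch() above: each supported dtype selects one
// template instantiation, and unsupported dtypes fall through to nullptr.
AllreduceFunc dispatchSketch(int dtype) {
  if (dtype == 7) {          // ncclFloat
    return Adapter<float>::call;
  } else if (dtype == 10) {  // ncclFloat8e4m3
    return Adapter<Fp8E4M3>::call;
  } else if (dtype == 11) {  // ncclFloat8e5m2
    return Adapter<Fp8E5M2>::call;
  }
  return nullptr;            // unsupported dtype
}
```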
