@@ -18,23 +18,6 @@ namespace quickreduce {
1818using fptr_t = int64_t ;
1919static_assert (sizeof (void *) == sizeof (fptr_t ));
2020
21- static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12 ;
22- static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8 ;
23- static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4 ;
24- static constexpr unsigned int kOneShotAllreduceMaxSize =
25- std::max (kOneShotAllreduceMaxElemsWorldSize2 * 2 ,
26- std::max (kOneShotAllreduceMaxElemsWorldSize4 * 4 ,
27- kOneShotAllreduceMaxElemsWorldSize8 * 8 )) *
28- sizeof(half);
29-
30- template <typename AllReduceKernel, typename T>
31- __global__ __quickreduce_launch_bounds_one_shot__ static void
32- allreduce_prototype_oneshot (T const * A, T* B, uint32_t N, int rank,
33- uint8_t ** dbuffer_list, uint32_t data_offset,
34- uint32_t flag_color) {
35- AllReduceKernel::run (A, B, N, rank, dbuffer_list, data_offset, flag_color);
36- }
37-
3821template <typename AllReduceKernel, typename T>
3922__global__ __quickreduce_launch_bounds_two_shot__ static void
4023allreduce_prototype_twoshot (T const * A, T* B, uint32_t N, int num_blocks,
@@ -50,24 +33,6 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
5033 }
5134}
5235
53- #define ONESHOT_DISPATCH () \
54- if (world_size == 2 ) { \
55- using AllReduceKernel = AllReduceOneshot<T, 2 >; \
56- hipLaunchKernelGGL ((allreduce_prototype_oneshot<AllReduceKernel, T>), \
57- dim3 (grid), dim3 (kBlockOneShot ), 0 , stream, A, B, N, \
58- rank, dbuffer_list, data_offset, flag_color); \
59- } else if (world_size == 4 ) { \
60- using AllReduceKernel = AllReduceOneshot<T, 4 >; \
61- hipLaunchKernelGGL ((allreduce_prototype_oneshot<AllReduceKernel, T>), \
62- dim3 (grid), dim3 (kBlockOneShot ), 0 , stream, A, B, N, \
63- rank, dbuffer_list, data_offset, flag_color); \
64- } else if (world_size == 8 ) { \
65- using AllReduceKernel = AllReduceOneshot<T, 8 >; \
66- hipLaunchKernelGGL ((allreduce_prototype_oneshot<AllReduceKernel, T>), \
67- dim3 (grid), dim3 (kBlockOneShot ), 0 , stream, A, B, N, \
68- rank, dbuffer_list, data_offset, flag_color); \
69- }
70-
7136#define TWOSHOT_DISPATCH (__codec ) \
7237 if (world_size == 2 ) { \
7338 using LineCodec = __codec<T, 2 >; \
@@ -132,8 +97,7 @@ struct DeviceComms {
13297
13398 // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer.
13499 uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof (int );
135- static constexpr int64_t data_buffer_size = std::max (
136- 2 * kMaxProblemSize , static_cast <int64_t >(kOneShotAllreduceMaxSize ));
100+ static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize ;
137101 int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
138102 data_offset = flags_buffer_size;
139103 HIP_CHECK (hipExtMallocWithFlags ((void **)&dbuffer, total_buffer_size,
@@ -204,33 +168,22 @@ struct DeviceComms {
204168
205169 // Configuration.
206170 uint32_t msg_size = N * sizeof (T);
207- bool use_one_shot_allreduce =
208- (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2 ) or
209- (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4 ) or
210- (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8 );
211- if (use_one_shot_allreduce) {
212- // Each thread processes blocks out of 4 elements
213- uint64_t num_blocks = divceil (N, (4 * kThreadsOneShot ));
214- uint64_t grid = min (kMaxNumBlocks , num_blocks);
215- ONESHOT_DISPATCH ()
216- } else {
217- uint64_t num_blocks = divceil (msg_size, kTileSize );
218- uint64_t grid = min (kMaxNumBlocks , num_blocks);
219- auto quant_level_ = static_cast <QuickReduceQuantLevel>(quant_level);
220- switch (quant_level_) {
221- case QuickReduceQuantLevel::INT8:
222- TWOSHOT_DISPATCH (CodecQ8)
223- break ;
224- case QuickReduceQuantLevel::INT6:
225- TWOSHOT_DISPATCH (CodecQ6)
226- break ;
227- case QuickReduceQuantLevel::INT4:
228- TWOSHOT_DISPATCH (CodecQ4)
229- break ;
230- default :
231- TWOSHOT_DISPATCH (CodecFP)
232- break ;
233- }
171+ uint64_t num_blocks = divceil (msg_size, kTileSize );
172+ uint64_t grid = min (kMaxNumBlocks , num_blocks);
173+ auto quant_level_ = static_cast <QuickReduceQuantLevel>(quant_level);
174+ switch (quant_level_) {
175+ case QuickReduceQuantLevel::INT8:
176+ TWOSHOT_DISPATCH (CodecQ8)
177+ break ;
178+ case QuickReduceQuantLevel::INT6:
179+ TWOSHOT_DISPATCH (CodecQ6)
180+ break ;
181+ case QuickReduceQuantLevel::INT4:
182+ TWOSHOT_DISPATCH (CodecQ4)
183+ break ;
184+ default :
185+ TWOSHOT_DISPATCH (CodecFP)
186+ break ;
234187 }
235188 HIP_CHECK (cudaGetLastError ());
236189 // Rotate the flag color.
0 commit comments