
Commit c588abb

Merge pull request #95 from lihaoyang-amd/lhy_integrate_qr2cr
Integrate quick allreduce to custom allreduce
2 parents 06af4d3 + 94b6cce commit c588abb

8 files changed, +271 -517 lines changed

csrc/quickreduce/base.h

Lines changed: 0 additions & 3 deletions

@@ -51,9 +51,6 @@ static constexpr int kWavefront = 64;
 // 256 thread, 4 wavefronts.
 static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
 
-static constexpr int kThreadsOneShot = 512;
-static dim3 constexpr kBlockOneShot = {kThreadsOneShot, 1, 1};
-
 // Number of threads in a group for quantization
 // It corresponds to 32 F16 elements in quantization block
 static constexpr int kThreadGroupSize = 8;
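
For orientation on the launch geometry these constants describe: the surviving kBlockTwoShot shape is four 64-lane wavefronts per block, while the removed one-shot path launched flat 512-thread blocks. Below is a host-only sketch that just re-evaluates that arithmetic; Dim3 is a stand-in for HIP's dim3, kBlockSize = 256 is inferred from the "256 thread, 4 wavefronts" comment above, and none of this is part of the library.

// Illustrative host program; mirrors the constants in base.h but is not
// part of quickreduce.
#include <cstdio>

struct Dim3 { unsigned x, y, z; };  // host stand-in for HIP's dim3

static constexpr int kWavefront = 64;        // from base.h
static constexpr int kBlockSize = 256;       // "256 thread, 4 wavefronts"
static constexpr int kThreadsOneShot = 512;  // removed by this commit

static constexpr Dim3 kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
static constexpr Dim3 kBlockOneShot = {kThreadsOneShot, 1, 1};

int main() {
  // A block launches x * y * z threads in total.
  std::printf("two-shot block {%u,%u,%u} -> %u threads\n", kBlockTwoShot.x,
              kBlockTwoShot.y, kBlockTwoShot.z,
              kBlockTwoShot.x * kBlockTwoShot.y * kBlockTwoShot.z);
  std::printf("one-shot block {%u,%u,%u} -> %u threads (removed)\n",
              kBlockOneShot.x, kBlockOneShot.y, kBlockOneShot.z,
              kBlockOneShot.x * kBlockOneShot.y * kBlockOneShot.z);
  return 0;
}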

csrc/quickreduce/quick_reduce.h

Lines changed: 17 additions & 64 deletions

@@ -18,23 +18,6 @@ namespace quickreduce {
 using fptr_t = int64_t;
 static_assert(sizeof(void*) == sizeof(fptr_t));
 
-static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12;
-static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8;
-static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4;
-static constexpr unsigned int kOneShotAllreduceMaxSize =
-    std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2,
-             std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4,
-                      kOneShotAllreduceMaxElemsWorldSize8 * 8)) *
-    sizeof(half);
-
-template <typename AllReduceKernel, typename T>
-__global__ __quickreduce_launch_bounds_one_shot__ static void
-allreduce_prototype_oneshot(T const* A, T* B, uint32_t N, int rank,
-                            uint8_t** dbuffer_list, uint32_t data_offset,
-                            uint32_t flag_color) {
-  AllReduceKernel::run(A, B, N, rank, dbuffer_list, data_offset, flag_color);
-}
-
 template <typename AllReduceKernel, typename T>
 __global__ __quickreduce_launch_bounds_two_shot__ static void
 allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
@@ -50,24 +33,6 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
   }
 }
 
-#define ONESHOT_DISPATCH()                                                  \
-  if (world_size == 2) {                                                    \
-    using AllReduceKernel = AllReduceOneshot<T, 2>;                         \
-    hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>),   \
-                       dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \
-                       rank, dbuffer_list, data_offset, flag_color);        \
-  } else if (world_size == 4) {                                             \
-    using AllReduceKernel = AllReduceOneshot<T, 4>;                         \
-    hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>),   \
-                       dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \
-                       rank, dbuffer_list, data_offset, flag_color);        \
-  } else if (world_size == 8) {                                             \
-    using AllReduceKernel = AllReduceOneshot<T, 8>;                         \
-    hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>),   \
-                       dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \
-                       rank, dbuffer_list, data_offset, flag_color);        \
-  }
-
 #define TWOSHOT_DISPATCH(__codec)                                           \
   if (world_size == 2) {                                                    \
     using LineCodec = __codec<T, 2>;                                        \
@@ -132,8 +97,7 @@ struct DeviceComms {
 
     // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer.
     uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int);
-    static constexpr int64_t data_buffer_size = std::max(
-        2 * kMaxProblemSize, static_cast<int64_t>(kOneShotAllreduceMaxSize));
+    static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize;
     int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
     data_offset = flags_buffer_size;
     HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
@@ -204,33 +168,22 @@ struct DeviceComms {
 
    // Configuration.
     uint32_t msg_size = N * sizeof(T);
-    bool use_one_shot_allreduce =
-        (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2) or
-        (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4) or
-        (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8);
-    if (use_one_shot_allreduce) {
-      // Each thread processes blocks out of 4 elements
-      uint64_t num_blocks = divceil(N, (4 * kThreadsOneShot));
-      uint64_t grid = min(kMaxNumBlocks, num_blocks);
-      ONESHOT_DISPATCH()
-    } else {
-      uint64_t num_blocks = divceil(msg_size, kTileSize);
-      uint64_t grid = min(kMaxNumBlocks, num_blocks);
-      auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
-      switch (quant_level_) {
-        case QuickReduceQuantLevel::INT8:
-          TWOSHOT_DISPATCH(CodecQ8)
-          break;
-        case QuickReduceQuantLevel::INT6:
-          TWOSHOT_DISPATCH(CodecQ6)
-          break;
-        case QuickReduceQuantLevel::INT4:
-          TWOSHOT_DISPATCH(CodecQ4)
-          break;
-        default:
-          TWOSHOT_DISPATCH(CodecFP)
-          break;
-      }
+    uint64_t num_blocks = divceil(msg_size, kTileSize);
+    uint64_t grid = min(kMaxNumBlocks, num_blocks);
+    auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
+    switch (quant_level_) {
+      case QuickReduceQuantLevel::INT8:
+        TWOSHOT_DISPATCH(CodecQ8)
+        break;
+      case QuickReduceQuantLevel::INT6:
+        TWOSHOT_DISPATCH(CodecQ6)
+        break;
+      case QuickReduceQuantLevel::INT4:
+        TWOSHOT_DISPATCH(CodecQ4)
+        break;
+      default:
+        TWOSHOT_DISPATCH(CodecFP)
+        break;
     }
     HIP_CHECK(cudaGetLastError());
     // Rotate the flag color.
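
For context on what the deleted path encoded: one-shot was selected only when the element count stayed within a per-world-size cap, and the worst-case one-shot staging buffer came out to max(8192*12*2, 8192*8*4, 8192*4*8) * sizeof(half) = 262144 * 2 bytes = 512 KiB, which is what the old std::max in the buffer allocation accounted for. The snippet below is a standalone, host-only re-evaluation of that arithmetic; it uses shortened names (kMaxElemsWS2, use_one_shot) for illustration, and none of these symbols exist in quick_reduce.h after this commit.

// Illustrative host-side re-evaluation of the constants removed from
// quick_reduce.h; not part of the library after this commit.
#include <algorithm>
#include <cstdio>

static constexpr unsigned kMaxElemsWS2 = 8192 * 12;  // world_size == 2
static constexpr unsigned kMaxElemsWS4 = 8192 * 8;   // world_size == 4
static constexpr unsigned kMaxElemsWS8 = 8192 * 4;   // world_size == 8

// Worst-case one-shot staging size: elements * world_size * sizeof(half).
static constexpr unsigned kOneShotMaxSize =
    std::max({kMaxElemsWS2 * 2, kMaxElemsWS4 * 4, kMaxElemsWS8 * 8}) * 2;

// The gate the removed code applied before choosing the one-shot kernel.
constexpr bool use_one_shot(int world_size, unsigned n_elems) {
  return (world_size == 2 && n_elems <= kMaxElemsWS2) ||
         (world_size == 4 && n_elems <= kMaxElemsWS4) ||
         (world_size == 8 && n_elems <= kMaxElemsWS8);
}

int main() {
  std::printf("one-shot worst-case buffer: %u bytes (512 KiB)\n",
              kOneShotMaxSize);
  std::printf("world_size=8, N=32768 -> one-shot? %d\n",
              use_one_shot(8, 32768));  // 1: 32768 <= 8192 * 4
  std::printf("world_size=8, N=65536 -> one-shot? %d\n",
              use_one_shot(8, 65536));  // 0: falls back to two-shot
  return 0;
}

With the gate gone, every message size now flows through the two-shot dispatch above, so the staging buffer only needs 2 * kMaxProblemSize bytes.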

csrc/quickreduce/quick_reduce_impl.cuh

Lines changed: 0 additions & 104 deletions

@@ -677,108 +677,4 @@ struct AllReduceTwoshot {
   }
 };
 
-// Oneshot AllReduce
-template <typename T, int world_size>
-struct AllReduceOneshot {
-  static_assert(sizeof(T) == 2);
-
-  __device__ static void run(
-      T const* __restrict__ A,             // input
-      T* __restrict__ B,                   // output
-      uint32_t const N,                    // number of elements
-      uint32_t const rank,                 // rank index
-      uint8_t** __restrict__ buffer_list,  // communication buffers
-      long const data_offset,  // offset to start of the data buffer
-      uint32_t flag_color) {
-    BufferResource src_buffer(const_cast<T*>(A), N * sizeof(T));
-    BufferResource dst_buffer(B, N * sizeof(T));
-
-    uint8_t* rank_buffer = buffer_list[rank];
-
-    const int block_size = blockDim.x;
-    const int thread = threadIdx.x;
-    const int block = blockIdx.x;
-    const uint32_t problem_size = (N + 3) / 4;
-
-    int32x4_t tA, tB;
-    long grid = gridDim.x;
-    long data_stride = grid * block_size * sizeof(int32x4_t);
-    long comm_flags0_offset = block * (world_size * sizeof(int));
-    long comm_flags1_offset =
-        comm_flags0_offset + grid * (world_size * sizeof(int));
-
-    for (int idx = block * block_size + thread; idx < problem_size;
-         idx += grid * block_size) {
-      // load values
-      tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t),
-                               0, 0);
-
-      // Write rank data into this rank segment of every rank's communication
-      // buffer.
-#pragma unroll
-      for (int r = 0; r < world_size; r++) {
-        int32x4_t* send_buffer = reinterpret_cast<int32x4_t*>(
-            buffer_list[r] + data_offset + rank * data_stride +
-            idx * sizeof(int32x4_t));
-        __builtin_nontemporal_store(tA, send_buffer);
-      }
-    }
-
-    __syncthreads();
-    if (thread < world_size) {
-      int r = thread;
-      int* peer_flag_ptr = reinterpret_cast<int*>(
-          buffer_list[r] + comm_flags0_offset + rank * sizeof(int));
-      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE);
-      int* self_flag_ptr = reinterpret_cast<int*>(
-          rank_buffer + comm_flags0_offset + r * sizeof(int));
-
-      // Wait for the flags to be set.
-      while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) {
-      }
-    }
-    __syncthreads();
-
-    for (int idx = block * block_size + thread; idx < problem_size;
-         idx += grid * block_size) {
-      {
-        int r = 0;
-        // Read posted data from the rank's communication buffer.
-        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
-            rank_buffer + data_offset + r * data_stride +
-            idx * sizeof(int32x4_t));
-        tA = __builtin_nontemporal_load(recv_buffer);
-      }
-#pragma unroll
-      for (int r = 1; r < world_size; r++) {
-        // Read posted data from the rank's communication buffer.
-        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
-            rank_buffer + data_offset + r * data_stride +
-            idx * sizeof(int32x4_t));
-        tB = __builtin_nontemporal_load(recv_buffer);
-
-        // Reduce the local data with the read data
-        packed_assign_add<T>(&tA, &tB);
-      }
-
-      buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t),
-                           0, 0);
-    }
-
-    __syncthreads();
-    if (thread < world_size) {
-      int r = thread;
-      int* peer_flag_ptr = reinterpret_cast<int*>(
-          buffer_list[r] + comm_flags1_offset + rank * sizeof(int));
-      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED);
-      int* self_flag_ptr = reinterpret_cast<int*>(
-          rank_buffer + comm_flags1_offset + r * sizeof(int));
-
-      // Wait for the flags to be set.
-      while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) {
-      }
-    }
-  }
-};
-
 } // namespace quickreduce
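
The removed AllReduceOneshot relied on a "flag color" handshake: each rank stores the current flag value into a slot on every peer with release semantics, then spins on its own slots with acquire semantics until all peers have arrived, and the color is rotated between invocations so the flags never need to be reset. The following is a minimal host-side analogy of that pattern using std::thread and std::atomic; the names (kWorldSize, flag_barrier, flags) are invented for illustration and nothing here is part of the quickreduce API.

// Host-side analogy of the release/acquire flag-color barrier used by the
// removed one-shot kernel. Each "rank" posts flag_color into its slot on
// every peer, then waits until every peer has posted into its own slots.
#include <array>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

constexpr int kWorldSize = 4;

// flags[r][p]: slot on rank r written by peer p (one int per peer, like the
// kernel's comm_flags buffers).
std::array<std::array<std::atomic<int>, kWorldSize>, kWorldSize> flags{};

void flag_barrier(int rank, int flag_color) {
  // Post our arrival on every rank's buffer; the release stores pair with
  // the acquire loads below and order the data writes made before the barrier.
  for (int r = 0; r < kWorldSize; ++r) {
    flags[r][rank].store(flag_color, std::memory_order_release);
  }
  // Wait until every peer has posted into our buffer.
  for (int p = 0; p < kWorldSize; ++p) {
    while (flags[rank][p].load(std::memory_order_acquire) != flag_color) {
    }
  }
}

int main() {
  std::vector<std::thread> ranks;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    ranks.emplace_back([rank] {
      for (int flag_color = 1; flag_color <= 3; ++flag_color) {
        // ... write this rank's data into the peers' buffers here ...
        flag_barrier(rank, flag_color);  // rotate the color each round
        // ... safe to read every peer's contribution here ...
      }
      std::printf("rank %d done\n", rank);
    });
  }
  for (auto& t : ranks) t.join();
  return 0;
}

In the deleted kernel the same handshake appears twice (comm_flags0 before reading the peers' posted data, comm_flags1 before returning), and the caller rotates flag_color between allreduce invocations, which is why the flags never need to be cleared.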

tests/distributed/test_quick_reduce.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

0 commit comments
