csrc/quickreduce/base.h (0 additions & 3 deletions)
@@ -51,9 +51,6 @@ static constexpr int kWavefront = 64;
// 256 thread, 4 wavefronts.
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};

static constexpr int kThreadsOneShot = 512;
static dim3 constexpr kBlockOneShot = {kThreadsOneShot, 1, 1};

// Number of threads in a group for quantization
// It corresponds to 32 F16 elements in quantization block
static constexpr int kThreadGroupSize = 8;
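A compile-time restatement of the geometry these constants describe may help: a two-shot block is four 64-wide wavefronts, and an 8-thread quantization group covers 32 F16 values, i.e. 4 halfs (one 64-bit load) per thread. The sketch below is illustrative only and is not part of base.h; kBlockSize = 256 is inferred from the "256 thread, 4 wavefronts" comment, and kF16PerQuantBlock is a name introduced here.

// Illustration only: compile-time arithmetic implied by the constants above.
static constexpr int kWavefront = 64;        // threads per wavefront
static constexpr int kBlockSize = 256;       // threads per two-shot block (assumed)
static constexpr int kThreadGroupSize = 8;   // threads per quantization group
static constexpr int kF16PerQuantBlock = 32; // F16 elements per quantization block

static_assert(kBlockSize / kWavefront == 4,
              "a two-shot block is four wavefronts");
static_assert(kF16PerQuantBlock / kThreadGroupSize == 4,
              "each thread in a group covers 4 half values (64 bits)");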
csrc/quickreduce/quick_reduce.h (17 additions & 64 deletions)
@@ -18,23 +18,6 @@ namespace quickreduce {
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12;
static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8;
static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4;
static constexpr unsigned int kOneShotAllreduceMaxSize =
std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2,
std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4,
kOneShotAllreduceMaxElemsWorldSize8 * 8)) *
sizeof(half);

template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_one_shot__ static void
allreduce_prototype_oneshot(T const* A, T* B, uint32_t N, int rank,
uint8_t** dbuffer_list, uint32_t data_offset,
uint32_t flag_color) {
AllReduceKernel::run(A, B, N, rank, dbuffer_list, data_offset, flag_color);
}

template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
@@ -50,24 +33,6 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
}
}

#define ONESHOT_DISPATCH() \
if (world_size == 2) { \
using AllReduceKernel = AllReduceOneshot<T, 2>; \
hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \
rank, dbuffer_list, data_offset, flag_color); \
} else if (world_size == 4) { \
using AllReduceKernel = AllReduceOneshot<T, 4>; \
hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \
rank, dbuffer_list, data_offset, flag_color); \
} else if (world_size == 8) { \
using AllReduceKernel = AllReduceOneshot<T, 8>; \
hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \
rank, dbuffer_list, data_offset, flag_color); \
}

#define TWOSHOT_DISPATCH(__codec) \
if (world_size == 2) { \
using LineCodec = __codec<T, 2>; \
@@ -132,8 +97,7 @@ struct DeviceComms {

// Allocate buffer size for worst case: Twoshot FP16 2-stage buffer.
uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int);
-    static constexpr int64_t data_buffer_size = std::max(
-        2 * kMaxProblemSize, static_cast<int64_t>(kOneShotAllreduceMaxSize));
+    static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize;
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
data_offset = flags_buffer_size;
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
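Restated on the host side, the allocation above is a flags region followed by a data region, and with the one-shot path gone the data region only needs the two-shot worst case of 2 * kMaxProblemSize bytes. The sketch below is not code from this PR: BufferLayout and plan_buffer are hypothetical names, and kMaxTiles / kMaxProblemSize (defined elsewhere in the file) are passed in as plain parameters.

#include <cstdint>

struct BufferLayout {
  int64_t flags_bytes;  // 2 * world_size * kMaxTiles * sizeof(int)
  int64_t data_offset;  // data region starts right after the flags
  int64_t total_bytes;  // size handed to hipExtMallocWithFlags
};

inline BufferLayout plan_buffer(int world_size, int64_t max_tiles,
                                int64_t max_problem_size) {
  BufferLayout layout;
  layout.flags_bytes =
      2 * world_size * max_tiles * static_cast<int64_t>(sizeof(int));
  // Two-shot FP16 worst case: two staging buffers of kMaxProblemSize bytes.
  const int64_t data_bytes = 2 * max_problem_size;
  layout.data_offset = layout.flags_bytes;
  layout.total_bytes = layout.flags_bytes + data_bytes;
  return layout;
}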
@@ -204,33 +168,22 @@ struct DeviceComms {

// Configuration.
uint32_t msg_size = N * sizeof(T);
-    bool use_one_shot_allreduce =
-        (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2) or
-        (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4) or
-        (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8);
-    if (use_one_shot_allreduce) {
-      // Each thread processes blocks out of 4 elements
-      uint64_t num_blocks = divceil(N, (4 * kThreadsOneShot));
-      uint64_t grid = min(kMaxNumBlocks, num_blocks);
-      ONESHOT_DISPATCH()
-    } else {
-      uint64_t num_blocks = divceil(msg_size, kTileSize);
-      uint64_t grid = min(kMaxNumBlocks, num_blocks);
-      auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
-      switch (quant_level_) {
-        case QuickReduceQuantLevel::INT8:
-          TWOSHOT_DISPATCH(CodecQ8)
-          break;
-        case QuickReduceQuantLevel::INT6:
-          TWOSHOT_DISPATCH(CodecQ6)
-          break;
-        case QuickReduceQuantLevel::INT4:
-          TWOSHOT_DISPATCH(CodecQ4)
-          break;
-        default:
-          TWOSHOT_DISPATCH(CodecFP)
-          break;
-      }
+    uint64_t num_blocks = divceil(msg_size, kTileSize);
+    uint64_t grid = min(kMaxNumBlocks, num_blocks);
+    auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
+    switch (quant_level_) {
+      case QuickReduceQuantLevel::INT8:
+        TWOSHOT_DISPATCH(CodecQ8)
+        break;
+      case QuickReduceQuantLevel::INT6:
+        TWOSHOT_DISPATCH(CodecQ6)
+        break;
+      case QuickReduceQuantLevel::INT4:
+        TWOSHOT_DISPATCH(CodecQ4)
+        break;
+      default:
+        TWOSHOT_DISPATCH(CodecFP)
+        break;
}
HIP_CHECK(cudaGetLastError());
// Rotate the flag color.
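For reference, the thresholds deleted from this file work out as follows. The one-shot path staged every rank's full N-element input in a per-rank segment of each peer's buffer, so the data region had to reserve max over world sizes of (element limit * world_size) * sizeof(half) extra bytes. The snippet below only restates that arithmetic; kMaxElemsWS2/4/8 and kOneShotMaxBytes are illustrative names, and sizeof(uint16_t) stands in for sizeof(half).

#include <algorithm>
#include <cstdint>

// Element limits the removed heuristic used per world size (half elements).
static constexpr uint64_t kMaxElemsWS2 = 8192 * 12;  // 98,304
static constexpr uint64_t kMaxElemsWS4 = 8192 * 8;   // 65,536
static constexpr uint64_t kMaxElemsWS8 = 8192 * 4;   // 32,768

// Worst-case staging bytes: world_size segments of N halfs, 2 bytes each.
static constexpr uint64_t kOneShotMaxBytes =
    std::max({kMaxElemsWS2 * 2, kMaxElemsWS4 * 4, kMaxElemsWS8 * 8}) *
    sizeof(uint16_t);
static_assert(kOneShotMaxBytes == 512 * 1024,
              "512 KiB of one-shot head-room was reserved");

With the heuristic removed, every message size now takes the tile-based two-shot path, launching divceil(msg_size, kTileSize) blocks capped at kMaxNumBlocks.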
csrc/quickreduce/quick_reduce_impl.cuh (0 additions & 104 deletions)
@@ -677,108 +677,4 @@ struct AllReduceTwoshot {
}
};

// Oneshot AllReduce
template <typename T, int world_size>
struct AllReduceOneshot {
static_assert(sizeof(T) == 2);

__device__ static void run(
T const* __restrict__ A, // input
T* __restrict__ B, // output
uint32_t const N, // number of elements
uint32_t const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
long const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
BufferResource src_buffer(const_cast<T*>(A), N * sizeof(T));
BufferResource dst_buffer(B, N * sizeof(T));

uint8_t* rank_buffer = buffer_list[rank];

const int block_size = blockDim.x;
const int thread = threadIdx.x;
const int block = blockIdx.x;
const uint32_t problem_size = (N + 3) / 4;

int32x4_t tA, tB;
long grid = gridDim.x;
long data_stride = grid * block_size * sizeof(int32x4_t);
long comm_flags0_offset = block * (world_size * sizeof(int));
long comm_flags1_offset =
comm_flags0_offset + grid * (world_size * sizeof(int));

for (int idx = block * block_size + thread; idx < problem_size;
idx += grid * block_size) {
// load values
tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t),
0, 0);

// Write rank data into this rank segment of every rank's communication
// buffer.
#pragma unroll
for (int r = 0; r < world_size; r++) {
int32x4_t* send_buffer = reinterpret_cast<int32x4_t*>(
buffer_list[r] + data_offset + rank * data_stride +
idx * sizeof(int32x4_t));
__builtin_nontemporal_store(tA, send_buffer);
}
}

__syncthreads();
if (thread < world_size) {
int r = thread;
int* peer_flag_ptr = reinterpret_cast<int*>(
buffer_list[r] + comm_flags0_offset + rank * sizeof(int));
__atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE);
int* self_flag_ptr = reinterpret_cast<int*>(
rank_buffer + comm_flags0_offset + r * sizeof(int));

// Wait for the flags to be set.
while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) {
}
}
__syncthreads();

for (int idx = block * block_size + thread; idx < problem_size;
idx += grid * block_size) {
{
int r = 0;
// Read posted data from the rank's communication buffer.
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
rank_buffer + data_offset + r * data_stride +
idx * sizeof(int32x4_t));
tA = __builtin_nontemporal_load(recv_buffer);
}
#pragma unroll
for (int r = 1; r < world_size; r++) {
// Read posted data from the rank's communication buffer.
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
rank_buffer + data_offset + r * data_stride +
idx * sizeof(int32x4_t));
tB = __builtin_nontemporal_load(recv_buffer);

// Reduce the local data with the read data
packed_assign_add<T>(&tA, &tB);
}

buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t),
0, 0);
}

__syncthreads();
if (thread < world_size) {
int r = thread;
int* peer_flag_ptr = reinterpret_cast<int*>(
buffer_list[r] + comm_flags1_offset + rank * sizeof(int));
__atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED);
int* self_flag_ptr = reinterpret_cast<int*>(
rank_buffer + comm_flags1_offset + r * sizeof(int));

// Wait for the flags to be set.
while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) {
}
}
}
};

} // namespace quickreduce
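The deleted AllReduceOneshot kernel separated its scatter and reduce phases with a per-block flag exchange; the sketch below distills that handshake into a standalone device helper. It is not code from the tree: flag_handshake is a hypothetical name, and the real kernel uses release/acquire ordering for the first barrier and relaxed ordering for the second.

// Sketch: each of the first world_size threads of a block signals one peer and
// then spins until that peer has signalled us, forming an all-to-all barrier
// for this block's data.
__device__ inline void flag_handshake(uint8_t** buffer_list,  // per-rank buffers
                                      long flags_offset,      // this block's flag slots
                                      int rank, int world_size,
                                      uint32_t flag_color) {
  __syncthreads();
  if (threadIdx.x < world_size) {
    const int peer = threadIdx.x;
    // Mark our data as visible in the peer's buffer (release pairs with acquire).
    int* peer_flag = reinterpret_cast<int*>(
        buffer_list[peer] + flags_offset + rank * sizeof(int));
    __atomic_store_n(peer_flag, flag_color, __ATOMIC_RELEASE);
    // Wait until the peer has published its data into our buffer.
    int* self_flag = reinterpret_cast<int*>(
        buffer_list[rank] + flags_offset + peer * sizeof(int));
    while (__atomic_load_n(self_flag, __ATOMIC_ACQUIRE) != flag_color) {
    }
  }
  __syncthreads();
}

The host rotates flag_color between calls (see the "Rotate the flag color" comment above), so the flag words never need to be reset.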
tests/distributed/test_quick_reduce.py (0 additions & 127 deletions)

This file was deleted.
