10 changes: 7 additions & 3 deletions csrc/custom_quickreduce.cu
@@ -10,6 +10,8 @@
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) {
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size == 6)
throw std::invalid_argument("world size == 6 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size)
@@ -20,9 +22,11 @@ quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) {
}

void qr_destroy(quickreduce::fptr_t _fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
fa->destroy();
delete fa;
if (_fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
fa->destroy();
delete fa;
}
}

torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
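The two changes in this file reject a world size of 6 (leaving 2, 4, and 8 as the supported sizes) and make `qr_destroy` a no-op on a null handle. Below is a minimal host-side sketch of the guarded destroy, with a hypothetical `Handle` type and `fptr_t` assumed to be a 64-bit integer handle; it is an illustration of the guard, not the real `DeviceComms` teardown.

```cpp
// Illustrative sketch only (hypothetical Handle type, fptr_t assumed int64_t).
// delete on a null pointer is already a no-op, but the old code also called
// fa->destroy() through the pointer, which is undefined behavior when the
// handle was never initialized (i.e. it is 0).
#include <cstdint>

using fptr_t = int64_t;  // opaque handle passed across the binding layer

struct Handle {
  void destroy() { /* release device buffers, IPC handles, ... */ }
};

void qr_destroy_sketch(fptr_t _fa) {
  if (_fa) {  // silently ignore an uninitialized handle
    auto* fa = reinterpret_cast<Handle*>(_fa);
    fa->destroy();
    delete fa;
  }
}

int main() {
  qr_destroy_sketch(0);                                      // safe no-op
  qr_destroy_sketch(reinterpret_cast<fptr_t>(new Handle));   // normal teardown
}
```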
165 changes: 37 additions & 128 deletions csrc/quickreduce/base.h
@@ -104,6 +104,10 @@ __quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
}
#endif
}
union bf162_int_union {
int i;
nv_bfloat162 bf2;
};

template <typename T>
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A,
@@ -152,10 +156,11 @@ __quickreduce_device_inline__ int packed_max<half>(int a, int b) {

template <>
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hmax2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmax2(A.bf2, B.bf2);
return R.i;
}

template <typename T>
@@ -170,10 +175,11 @@ __quickreduce_device_inline__ int packed_min<half>(int a, int b) {

template <>
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hmin2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmin2(A.bf2, B.bf2);
return R.i;
}

template <typename T>
@@ -194,15 +200,12 @@ __quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {

template <>
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
nv_bfloat162 wmaxh2 = *(reinterpret_cast<nv_bfloat162*>(&a));
nv_bfloat162 wminh2 = *(reinterpret_cast<nv_bfloat162*>(&b));
nv_bfloat162 wblockmaxh2;
wblockmaxh2.x =
__hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
wblockmaxh2.y =
__hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;

return *(reinterpret_cast<int*>(&wblockmaxh2));
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
return R.i;
}

template <typename T>
@@ -217,10 +220,11 @@ __quickreduce_device_inline__ int packed_add<half>(int a, int b) {

template <>
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hadd2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hadd2(A.bf2, B.bf2);
return R.i;
}

template <>
@@ -246,10 +250,11 @@ __quickreduce_device_inline__ int packed_sub<half>(int a, int b) {

template <>
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hsub2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hsub2(A.bf2, B.bf2);
return R.i;
}

template <typename T>
@@ -280,78 +285,21 @@ __quickreduce_device_inline__ int packed_rcp<half>(int a) {

template <>
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162 tR = h2rcp(*tA);
return *(reinterpret_cast<int*>(&tR));
bf162_int_union A, R;
A.i = a;
R.bf2 = h2rcp(A.bf2);
return R.i;
}

template <typename T>
__quickreduce_device_inline__ T float2T_cast(float a);

template <>
__quickreduce_device_inline__ half float2T_cast<half>(float a) {
return __float2half(a);
}

template <>
__quickreduce_device_inline__ nv_bfloat16 float2T_cast<nv_bfloat16>(float a) {
return __float2bfloat16(a);
}

template <typename T>
__quickreduce_device_inline__ float T2float_cast(T a);

template <>
__quickreduce_device_inline__ float T2float_cast<half>(half a) {
// changes dtype
__quickreduce_device_inline__ float T2float_cast(half a) {
return __half2float(a);
}

template <>
__quickreduce_device_inline__ float T2float_cast<nv_bfloat16>(nv_bfloat16 a) {
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
return __bfloat162float(a);
}

template <typename T>
__quickreduce_device_inline__ unsigned char T2uchar_cast(T a);

template <>
__quickreduce_device_inline__ unsigned char T2uchar_cast<half>(half a) {
return static_cast<unsigned char>(__half2ushort_rz(a));
}

template <>
__quickreduce_device_inline__ unsigned char T2uchar_cast<nv_bfloat16>(
nv_bfloat16 a) {
return static_cast<unsigned char>(__bfloat16_as_ushort(a));
}

template <typename T>
__quickreduce_device_inline__ T uchar2T_cast(unsigned char a);

template <>
__quickreduce_device_inline__ half uchar2T_cast<half>(unsigned char a) {
return __ushort2half_rz(static_cast<unsigned short>(a));
}

template <>
__quickreduce_device_inline__ nv_bfloat16
uchar2T_cast<nv_bfloat16>(unsigned char a) {
return __ushort_as_bfloat16(static_cast<unsigned short>(a));
}

template <typename T>
__quickreduce_device_inline__ int T2int_cast(T a);

template <>
__quickreduce_device_inline__ int T2int_cast<half>(half a) {
return __half2int_rz(a);
}

template <>
__quickreduce_device_inline__ int T2int_cast<nv_bfloat16>(nv_bfloat16 a) {
return static_cast<int>(__bfloat16_as_ushort(a));
}

template <typename T>
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
@@ -384,45 +332,6 @@ __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
return wblockmax;
}

template <typename T>
__quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& wblockmax,
int& wblockmin,
int valid_data) {
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
static constexpr int FP_MAX =
std::is_same<T, half>::value ? 0x7BFF7BFF : 0x7F7F7F7F;
static constexpr int FP_MIN =
std::is_same<T, half>::value ? 0xFBFFFBFF : 0xFF7FFF7F;

int wmax, wmin;
int a, b;
a = packed_max<T>(atom[0], atom[1]);
b = packed_max<T>(atom[2], atom[3]);
// In case the data was loaded out of range (and initialized to 0)
// we set max min values to sentinel values
// so that they do not spoil the group max min values
wmax = valid_data * packed_max<T>(a, b) + (!valid_data) * FP_MIN;

a = packed_min<T>(atom[0], atom[1]);
b = packed_min<T>(atom[2], atom[3]);
wmin = valid_data * packed_min<T>(a, b) + (!valid_data) * FP_MAX;

// Reduce the max and min among a group of threads
// Note: This is basically 2 blocks of values setup as the
// upper/lower halves of the f16x2_t
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
int x = __shfl_down(wmax, i);
wmax = packed_max<T>(wmax, x);

int y = __shfl_down(wmin, i);
wmin = packed_min<T>(wmin, y);
}

// Share with the cohort
wblockmax = __shfl(wmax, group_leader);
wblockmin = __shfl(wmin, group_leader);
}

__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr,
uint32_t flag) {
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
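The recurring edit in base.h replaces `reinterpret_cast` between `int*` and `nv_bfloat162*` with the `bf162_int_union` introduced at the top of the file. Dereferencing a pointer cast to an unrelated type violates strict aliasing; going through a union gives the compiler one well-defined object for both views. GCC documents type punning through unions as defined behavior, and Clang-based toolchains follow suit. A host-side sketch of the same pattern, with `float` standing in for `nv_bfloat162` and made-up names, assuming nothing beyond the standard library:

```cpp
// Host-side illustration of the union-punning idiom used in base.h
// (float stands in for nv_bfloat162; f32_int_union is a made-up name).
#include <cstdint>
#include <cstdio>

union f32_int_union {
  uint32_t i;  // packed 32-bit view, as passed between the packed_* helpers
  float f;     // element view the math intrinsics operate on
};

// Strict-aliasing-unfriendly version: dereferencing a float* that actually
// points at a uint32_t is undefined behavior and may be miscompiled at -O2.
// float bits_to_float_bad(uint32_t a) { return *reinterpret_cast<float*>(&a); }

// Union version: both views alias a single object the compiler knows about.
float bits_to_float(uint32_t a) {
  f32_int_union u;
  u.i = a;
  return u.f;  // type punning via the union, defined by GCC/Clang
}

int main() {
  f32_int_union u;
  u.f = 1.5f;
  std::printf("0x%08x -> %f\n", u.i, bits_to_float(u.i));
}
```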
10 changes: 7 additions & 3 deletions csrc/quickreduce/quick_reduce.h
@@ -94,16 +94,17 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,

enum QuickReduceQuantLevel {
FP16 = 0,
INT8,
INT4,
INT8 = 1,
INT6 = 2,
INT4 = 3,
};

struct DeviceComms {
// Workgroup scope = Tile = (256 threads x 16B x 8 atoms)
static long constexpr kTileSize = 256 * 16 * 8;

// Max problem size is 2GB (in bytes) or half of uint32_t max value.
static int64_t constexpr kMaxProblemSize = 2147483647;
static int64_t constexpr kMaxProblemSize = 2147483648;
static int64_t constexpr kMaxTiles = kMaxProblemSize / kTileSize;

// Max TP-8
@@ -220,6 +221,9 @@ struct DeviceComms {
case QuickReduceQuantLevel::INT8:
TWOSHOT_DISPATCH(CodecQ8)
break;
case QuickReduceQuantLevel::INT6:
TWOSHOT_DISPATCH(CodecQ6)
break;
case QuickReduceQuantLevel::INT4:
TWOSHOT_DISPATCH(CodecQ4)
break;
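Two things change in quick_reduce.h: the quant-level enum gains an explicit INT6 level (dispatched to `CodecQ6`), and `kMaxProblemSize` moves from INT32_MAX to 2^31, presumably so the "2GB" comment and the tile count come out exact. A compile-time sanity sketch of that arithmetic, assuming only the constants visible in the diff (not part of the PR):

```cpp
// Static checks of the tile math, using the constants shown in the
// DeviceComms diff above.
#include <cstdint>

static constexpr long kTileSize = 256 * 16 * 8;          // 256 threads x 16 B x 8 atoms
static constexpr int64_t kMaxProblemSize = 2147483648;   // 2 GiB == 2^31 bytes
static constexpr int64_t kMaxTiles = kMaxProblemSize / kTileSize;

static_assert(kTileSize == 32768, "one tile covers 32 KiB");
static_assert(kMaxTiles == 65536, "2^31 / 2^15 == 2^16 tiles");

// With the previous value 2147483647 (INT32_MAX) the division truncated to
// 65535 tiles, one short of what a full 2 GiB problem needs.
int main() {}
```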