Skip to content

Commit be1b514

Browse files
avbokovoy authored
and meta-codesync[bot] committed
group_index_select_or_add_2d_kernel forward pass optimization (#5080)
Summary: Pull Request resolved: #5080 X-link: https://github.com/facebookresearch/FBGEMM/pull/2087 This PR introduces optimization for `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel with primary focus on `float` type and relatively small embedding dimensions. 2 things are implemented: 1) Extracted the common variables out of the loop to omit unnecessary synchronizations on memory load (compiler won't do that automatically) 2) Switch to 32 threads logical wave sizes to reduce granularity losses. Pull Request resolved: #5078 Reviewed By: spcyppt, haoyuz Differential Revision: D86135611 Pulled By: q10 fbshipit-source-id: f4fb9966f5f5180c4dde2aed92ca726c260b7743
1 parent f1267c4 commit be1b514

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

fbgemm_gpu/src/sparse_ops/sparse_group_index.cu

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,18 @@ using Tensor = at::Tensor;
1212

1313
namespace fbgemm_gpu {
1414

15+
#ifdef USE_ROCM
16+
// The wave size is forced to be 32 on ROCm devices in favor
17+
// of granularity losses reduction.
18+
constexpr int EMULATED_WARP_SIZE = 32;
19+
#else
20+
constexpr int EMULATED_WARP_SIZE = kWarpSize;
21+
#endif
22+
1523
// TODO: Update UNROLL_FACTOR
1624
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
1725
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
18-
GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
26+
GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
1927

2028
// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
2129
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -43,12 +51,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
4351
const int64_t num_work_rows, // number of rows to work on per member
4452
const int64_t group_size) {
4553
const auto total_num_warps = warp_offsets_group[group_size];
54+
int32_t num_cols = 0;
55+
int32_t warps_per_row = 0;
56+
57+
if constexpr (!USE_VAR_COLS) {
58+
num_cols = num_cols_group[0];
59+
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
60+
}
61+
4662
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
4763
warp_id < total_num_warps;
4864
warp_id += gridDim.x * blockDim.y) {
49-
int32_t member_id, member_warp_id, num_cols, warps_per_row;
50-
if (USE_VAR_COLS) {
51-
__shared__ int member_ids[kMaxThreads / kWarpSize];
65+
int32_t member_id = 0;
66+
int32_t member_warp_id = 0;
67+
if constexpr (USE_VAR_COLS) {
68+
__shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
5269
if (threadIdx.x == 0) {
5370
binary_search_range(
5471
&member_ids[threadIdx.y],
@@ -63,8 +80,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
6380
member_warp_id = warp_id - warp_offsets_group[member_id];
6481
} else {
6582
// All columns are the same
66-
num_cols = num_cols_group[0];
67-
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
6883
member_id = warp_id / (warps_per_row * num_work_rows);
6984
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
7085
}
@@ -82,7 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
8297
#pragma unroll
8398
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
8499
// Compile time conditional
85-
if (USE_INDEX_SELECT) {
100+
if constexpr (USE_INDEX_SELECT) {
86101
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
87102
} else {
88103
gpuAtomicAddNoReturn(
@@ -113,13 +128,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
113128
at::cuda::OptionalCUDAGuard device_guard(device);
114129

115130
// Partition work based on num_work_rows
116-
uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
131+
uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
117132
uint32_t max_grid_size =
118133
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
119134
uint32_t grid_size = std::min(
120135
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
121136
max_grid_size);
122-
dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
137+
dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
123138

124139
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
125140
FBGEMM_LAUNCH_KERNEL( \

0 commit comments

Comments
 (0)