Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 1 addition & 65 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
Original file line number Diff line number Diff line change
Expand Up @@ -402,68 +402,6 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
}
}

template <typename offset_t, typename index_t>
__global__
__launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel_vec(
const offset_t* const length_data,
const offset_t* const offset_data,
offset_t* const bucketized_offsets_data,
const index_t* const bucket_mapping_data,
index_t* const bucketized_permute_data_out,
int32_t lengths_size) {
uint32_t threads_per_group = 1;
if (sizeof(offset_t) == 8 && sizeof(index_t) == 8) {
threads_per_group = 4;
}
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
uint32_t idx_within_group = tid % threads_per_group; // tid & 0x3;
uint32_t group_idx = tid / threads_per_group; // tid & 0xfffffffc;
for (uint32_t b_t = group_idx; b_t < lengths_size;
b_t += (gridDim.x * blockDim.x / threads_per_group)) {
const auto length = length_data[b_t];
const auto offset = offset_data[b_t];

// Check if we should use 8-byte atomic operations for this data type
// combination
if (sizeof(offset_t) == 8 && sizeof(index_t) == 8) {
// Check alignment once per b_t outside the inner loop
offset_t* const base_offset_ptr = &bucketized_offsets_data[b_t];
bool is_aligned = (reinterpret_cast<uintptr_t>(base_offset_ptr) %
alignof(unsigned long long int)) == 0;
if (is_aligned) {
for (uint32_t i = idx_within_group; i < length;
i += threads_per_group) {
const auto index = offset + i;
const auto bucket = bucket_mapping_data[index];
unsigned long long int* const p =
reinterpret_cast<unsigned long long int*>(
&bucketized_offsets_data[bucket * lengths_size + b_t]);
bucketized_permute_data_out[index] = atomicAdd(p, 1llu);
}
} else {
// Fall back to non-atomic increment if misaligned
for (uint32_t i = idx_within_group; i < length;
i += threads_per_group) {
const auto index = offset + i;
const auto bucket = bucket_mapping_data[index];
offset_t* const offset_ptr =
&bucketized_offsets_data[bucket * lengths_size + b_t];
bucketized_permute_data_out[index] = offset_ptr[0];
offset_ptr[0]++;
}
}
} else {
// 4-byte operations - no alignment check needed
for (uint32_t i = idx_within_group; i < length; i += threads_per_group) {
const auto index = offset + i;
const auto bucket = bucket_mapping_data[index];
bucketized_permute_data_out[index] =
bucketized_offsets_data[bucket * lengths_size + b_t]++;
}
}
}
}

#define LAUNCH_BLOCK_BUCKETIZE_SEQUENCE_SPARSE_FEATURES_CUDA_KERNEL_WITH_WEIGHT( \
bucketize_pos, return_bucket_mapping) \
AT_DISPATCH_INDEX_TYPES( \
Expand Down Expand Up @@ -1104,9 +1042,7 @@ DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
"_populate_bucketized_permute_cuda_kernel2",
[&] {
FBGEMM_LAUNCH_KERNEL(
(_populate_bucketized_permute_cuda_kernel_vec<
offset_t,
index_t>),
(_populate_bucketized_permute_cuda_kernel<offset_t, index_t>),
num_blocks,
threads_per_block,
0,
Expand Down
Loading