@@ -402,68 +402,6 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
   }
 }
 
-template <typename offset_t, typename index_t>
-__global__
-__launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel_vec(
-    const offset_t* const length_data,
-    const offset_t* const offset_data,
-    offset_t* const bucketized_offsets_data,
-    const index_t* const bucket_mapping_data,
-    index_t* const bucketized_permute_data_out,
-    int32_t lengths_size) {
-  uint32_t threads_per_group = 1;
-  if (sizeof(offset_t) == 8 && sizeof(index_t) == 8) {
-    threads_per_group = 4;
-  }
-  uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  uint32_t idx_within_group = tid % threads_per_group; // tid & 0x3;
-  uint32_t group_idx = tid / threads_per_group; // tid & 0xfffffffc;
-  for (uint32_t b_t = group_idx; b_t < lengths_size;
-       b_t += (gridDim.x * blockDim.x / threads_per_group)) {
-    const auto length = length_data[b_t];
-    const auto offset = offset_data[b_t];
-
-    // Check if we should use 8-byte atomic operations for this data type
-    // combination
-    if (sizeof(offset_t) == 8 && sizeof(index_t) == 8) {
-      // Check alignment once per b_t outside the inner loop
-      offset_t* const base_offset_ptr = &bucketized_offsets_data[b_t];
-      bool is_aligned = (reinterpret_cast<uintptr_t>(base_offset_ptr) %
-                         alignof(unsigned long long int)) == 0;
-      if (is_aligned) {
-        for (uint32_t i = idx_within_group; i < length;
-             i += threads_per_group) {
-          const auto index = offset + i;
-          const auto bucket = bucket_mapping_data[index];
-          unsigned long long int* const p =
-              reinterpret_cast<unsigned long long int*>(
-                  &bucketized_offsets_data[bucket * lengths_size + b_t]);
-          bucketized_permute_data_out[index] = atomicAdd(p, 1llu);
-        }
-      } else {
-        // Fall back to non-atomic increment if misaligned
-        for (uint32_t i = idx_within_group; i < length;
-             i += threads_per_group) {
-          const auto index = offset + i;
-          const auto bucket = bucket_mapping_data[index];
-          offset_t* const offset_ptr =
-              &bucketized_offsets_data[bucket * lengths_size + b_t];
-          bucketized_permute_data_out[index] = offset_ptr[0];
-          offset_ptr[0]++;
-        }
-      }
-    } else {
-      // 4-byte operations - no alignment check needed
-      for (uint32_t i = idx_within_group; i < length; i += threads_per_group) {
-        const auto index = offset + i;
-        const auto bucket = bucket_mapping_data[index];
-        bucketized_permute_data_out[index] =
-            bucketized_offsets_data[bucket * lengths_size + b_t]++;
-      }
-    }
-  }
-}
-
 #define LAUNCH_BLOCK_BUCKETIZE_SEQUENCE_SPARSE_FEATURES_CUDA_KERNEL_WITH_WEIGHT( \
     bucketize_pos, return_bucket_mapping) \
   AT_DISPATCH_INDEX_TYPES( \
@@ -1104,9 +1042,7 @@ DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
           "_populate_bucketized_permute_cuda_kernel2",
           [&] {
             FBGEMM_LAUNCH_KERNEL(
-                (_populate_bucketized_permute_cuda_kernel_vec<
-                    offset_t,
-                    index_t>),
+                (_populate_bucketized_permute_cuda_kernel<offset_t, index_t>),
                 num_blocks,
                 threads_per_block,
                 0,
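
For context, a minimal sketch of what the retained scalar kernel plausibly looks like after this change. It assumes the parameter list of the removed _vec variant, a kMaxThreads constant from the enclosing file, and one thread per lengths entry in a grid-stride loop; it is an illustrative reconstruction, not the actual FBGEMM source.

// Hypothetical sketch only; signature and kMaxThreads are assumed from the
// surrounding file, the body mirrors the 4-byte path of the removed _vec kernel.
template <typename offset_t, typename index_t>
__global__
__launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
    const offset_t* const length_data,
    const offset_t* const offset_data,
    offset_t* const bucketized_offsets_data,
    const index_t* const bucket_mapping_data,
    index_t* const bucketized_permute_data_out,
    int32_t lengths_size) {
  // One thread owns each lengths entry b_t, so every (bucket, b_t) cursor in
  // bucketized_offsets_data is written by exactly one thread.
  for (uint32_t b_t = blockIdx.x * blockDim.x + threadIdx.x; b_t < lengths_size;
       b_t += gridDim.x * blockDim.x) {
    const auto length = length_data[b_t];
    const auto offset = offset_data[b_t];
    for (uint32_t i = 0; i < length; ++i) {
      const auto index = offset + i;
      const auto bucket = bucket_mapping_data[index];
      // Record the current cursor for this bucket/entry, then advance it.
      bucketized_permute_data_out[index] =
          bucketized_offsets_data[bucket * lengths_size + b_t]++;
    }
  }
}

In this scalar form no atomics are needed, because each cursor is only ever touched by the thread that owns b_t. The removed _vec variant split each 8-byte entry across four threads per group, which is why it needed the atomicAdd path and the misalignment fallback.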