@@ -524,11 +524,11 @@ __device__ void filter_and_histogram(T const* in_buf,
524524 // merge histograms produced by individual blocks
525525 for (int i = threadIdx .x ; i < num_buckets; i += blockDim .x )
526526 {
527- // if(histogram_smem[i] != 0)
528- // {
529- // atomicAdd(histogram + i, histogram_smem[i]);
530- // }
531- *(histogram + i) = histogram_smem[i];
527+ if (histogram_smem[i] != 0 )
528+ {
529+ atomicAdd (histogram + i, histogram_smem[i]);
530+ }
531+ // *(histogram + i) = histogram_smem[i];
532532 }
533533}
534534
@@ -733,24 +733,27 @@ __global__ void last_filter_kernel(T const* in,
733733 T* out,
734734 IdxT* out_idx,
735735 IdxT len,
736+ const IdxT* rowStarts,
737+ const IdxT* rowEnds,
736738 IdxT k,
737739 Counter<T, IdxT>* counters,
738740 bool const select_min)
739741{
740742 const int64_t batch_id = blockIdx .y ; // size_t to avoid multiplication overflow
743+ const IdxT row_len = rowEnds[batch_id] - rowStarts[batch_id];
741744
742745 Counter<T, IdxT>* counter = counters + batch_id;
743746 IdxT previous_len = counter->previous_len ;
744747 if (previous_len == 0 )
745748 {
746749 return ;
747750 }
748- const IdxT buf_len = calc_buf_len<T>(len );
751+ const IdxT buf_len = calc_buf_len<T>(row_len );
749752 if (previous_len > buf_len || in_buf == in)
750753 {
751754 in_buf = in + batch_id * len;
752755 in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr ;
753- previous_len = len ;
756+ previous_len = row_len ;
754757 }
755758 else
756759 {
@@ -978,13 +981,13 @@ __global__ void radix_kernel(T const* in,
978981
979982 constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
980983 // reset for next pass
981- // if(pass != num_passes - 1)
982- // {
983- // for(int i = threadIdx.x; i < num_buckets; i += blockDim.x)
984- // {
985- // histogram[i] = 0;
986- // }
987- // }
984+ if (pass != num_passes - 1 )
985+ {
986+ for (int i = threadIdx .x ; i < num_buckets; i += blockDim .x )
987+ {
988+ histogram[i] = 0 ;
989+ }
990+ }
988991 if (threadIdx .x == 0 )
989992 {
990993 // `last_filter_kernel()` requires setting previous_len even in the last
@@ -1529,17 +1532,14 @@ void standalone_stable_radix_topk_(void* buf,
15291532 T* buf2 = nullptr ;
15301533 IdxT* idx_buf2 = nullptr ;
15311534
1532- IdxT* topk_out_idx = nullptr ;
1533-
15341535 {
1535- IdxT len_candidates = calc_buf_len<T>(len);
1536+ IdxT len_candidates = calc_buf_len<T, IdxT >(len);
15361537 std::vector<size_t > sizes = {sizeof (*counters) * batch_size,
15371538 sizeof (*histograms) * num_buckets * batch_size,
15381539 sizeof (*buf1) * len_candidates * batch_size,
15391540 sizeof (*idx_buf1) * len_candidates * batch_size,
15401541 sizeof (*buf2) * len_candidates * batch_size,
1541- sizeof (*idx_buf2) * len_candidates * batch_size,
1542- sizeof (*topk_out_idx) * k * batch_size};
1542+ sizeof (*idx_buf2) * len_candidates * batch_size};
15431543
15441544 size_t total_size = calc_aligned_size (sizes);
15451545 if (!buf)
@@ -1555,7 +1555,6 @@ void standalone_stable_radix_topk_(void* buf,
15551555 idx_buf1 = static_cast <decltype (idx_buf1)>(aligned_pointers[3 ]);
15561556 buf2 = static_cast <decltype (buf2)>(aligned_pointers[4 ]);
15571557 idx_buf2 = static_cast <decltype (idx_buf2)>(aligned_pointers[5 ]);
1558- topk_out_idx = static_cast <decltype (topk_out_idx)>(aligned_pointers[6 ]);
15591558
15601559 HIP_CALL (hipMemsetAsync (aligned_pointers[0 ],
15611560 0 ,
@@ -1614,9 +1613,9 @@ void standalone_stable_radix_topk_(void* buf,
16141613
16151614 if (!fused_last_filter)
16161615 {
1617- last_filter_kernel<T, IdxT, BitsPerPass, WRITE_TOPK_VALUES, true >
1616+ last_filter_kernel<T, IdxT, BitsPerPass, WRITE_TOPK_VALUES, false >
16181617 <<<blocks, BlockSize, 0 , stream>>> (
1619- in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min);
1618+ in, in_idx, out_buf, out_idx_buf, out, out_idx, len, rowStarts, rowEnds, k, counters, select_min);
16201619 }
16211620}
16221621
@@ -1708,7 +1707,7 @@ void standalone_stable_radix_11bits(void* buf,
17081707 unsigned grid_dim =
17091708 calc_grid_dim<T, IdxT, 11 , block_dim, WRITE_TOPK_VALUES>(batch_size, len, sm_cnt);
17101709
1711- if (grid_dim == 1 )
1710+ if (1 ) // faster
17121711 {
17131712 standalone_stable_radix_topk_one_block_<T, IdxT, 11 , block_dim, WRITE_TOPK_VALUES>(
17141713 buf,
@@ -2307,7 +2306,7 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
23072306 unsigned grid_dim =
23082307 aiter::calc_grid_dim<T, IdxT, 11 , block_dim, false >(numRows, stride0, sm_cnt);
23092308
2310- if (grid_dim == 1 )
2309+ if (1 )
23112310 {
23122311 aiter::standalone_stable_radix_topk_one_block_<T, IdxT, 11 , block_dim, false >(
23132312 workspace,
0 commit comments