Skip to content

Commit eae48c3

Browse files
authored
Merge branch 'main' into test_do_cluster
2 parents a55a4b1 + 6af8b68 commit eae48c3

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

csrc/kernels/topk_per_row_kernels.cu

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -524,11 +524,11 @@ __device__ void filter_and_histogram(T const* in_buf,
524524
// merge histograms produced by individual blocks
525525
for(int i = threadIdx.x; i < num_buckets; i += blockDim.x)
526526
{
527-
// if(histogram_smem[i] != 0)
528-
// {
529-
// atomicAdd(histogram + i, histogram_smem[i]);
530-
// }
531-
*(histogram + i) = histogram_smem[i];
527+
if(histogram_smem[i] != 0)
528+
{
529+
atomicAdd(histogram + i, histogram_smem[i]);
530+
}
531+
// *(histogram + i) = histogram_smem[i];
532532
}
533533
}
534534

@@ -733,24 +733,27 @@ __global__ void last_filter_kernel(T const* in,
733733
T* out,
734734
IdxT* out_idx,
735735
IdxT len,
736+
const IdxT* rowStarts,
737+
const IdxT* rowEnds,
736738
IdxT k,
737739
Counter<T, IdxT>* counters,
738740
bool const select_min)
739741
{
740742
const int64_t batch_id = blockIdx.y; // int64_t to avoid multiplication overflow
743+
const IdxT row_len = rowEnds[batch_id] - rowStarts[batch_id];
741744

742745
Counter<T, IdxT>* counter = counters + batch_id;
743746
IdxT previous_len = counter->previous_len;
744747
if(previous_len == 0)
745748
{
746749
return;
747750
}
748-
const IdxT buf_len = calc_buf_len<T>(len);
751+
const IdxT buf_len = calc_buf_len<T>(row_len);
749752
if(previous_len > buf_len || in_buf == in)
750753
{
751754
in_buf = in + batch_id * len;
752755
in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr;
753-
previous_len = len;
756+
previous_len = row_len;
754757
}
755758
else
756759
{
@@ -978,13 +981,13 @@ __global__ void radix_kernel(T const* in,
978981

979982
constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
980983
// reset for next pass
981-
// if(pass != num_passes - 1)
982-
// {
983-
// for(int i = threadIdx.x; i < num_buckets; i += blockDim.x)
984-
// {
985-
// histogram[i] = 0;
986-
// }
987-
// }
984+
if(pass != num_passes - 1)
985+
{
986+
for(int i = threadIdx.x; i < num_buckets; i += blockDim.x)
987+
{
988+
histogram[i] = 0;
989+
}
990+
}
988991
if(threadIdx.x == 0)
989992
{
990993
// `last_filter_kernel()` requires setting previous_len even in the last
@@ -1529,17 +1532,14 @@ void standalone_stable_radix_topk_(void* buf,
15291532
T* buf2 = nullptr;
15301533
IdxT* idx_buf2 = nullptr;
15311534

1532-
IdxT* topk_out_idx = nullptr;
1533-
15341535
{
1535-
IdxT len_candidates = calc_buf_len<T>(len);
1536+
IdxT len_candidates = calc_buf_len<T, IdxT>(len);
15361537
std::vector<size_t> sizes = {sizeof(*counters) * batch_size,
15371538
sizeof(*histograms) * num_buckets * batch_size,
15381539
sizeof(*buf1) * len_candidates * batch_size,
15391540
sizeof(*idx_buf1) * len_candidates * batch_size,
15401541
sizeof(*buf2) * len_candidates * batch_size,
1541-
sizeof(*idx_buf2) * len_candidates * batch_size,
1542-
sizeof(*topk_out_idx) * k * batch_size};
1542+
sizeof(*idx_buf2) * len_candidates * batch_size};
15431543

15441544
size_t total_size = calc_aligned_size(sizes);
15451545
if(!buf)
@@ -1555,7 +1555,6 @@ void standalone_stable_radix_topk_(void* buf,
15551555
idx_buf1 = static_cast<decltype(idx_buf1)>(aligned_pointers[3]);
15561556
buf2 = static_cast<decltype(buf2)>(aligned_pointers[4]);
15571557
idx_buf2 = static_cast<decltype(idx_buf2)>(aligned_pointers[5]);
1558-
topk_out_idx = static_cast<decltype(topk_out_idx)>(aligned_pointers[6]);
15591558

15601559
HIP_CALL(hipMemsetAsync(aligned_pointers[0],
15611560
0,
@@ -1614,9 +1613,9 @@ void standalone_stable_radix_topk_(void* buf,
16141613

16151614
if(!fused_last_filter)
16161615
{
1617-
last_filter_kernel<T, IdxT, BitsPerPass, WRITE_TOPK_VALUES, true>
1616+
last_filter_kernel<T, IdxT, BitsPerPass, WRITE_TOPK_VALUES, false>
16181617
<<<blocks, BlockSize, 0, stream>>>(
1619-
in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min);
1618+
in, in_idx, out_buf, out_idx_buf, out, out_idx, len, rowStarts, rowEnds, k, counters, select_min);
16201619
}
16211620
}
16221621

@@ -1708,7 +1707,7 @@ void standalone_stable_radix_11bits(void* buf,
17081707
unsigned grid_dim =
17091708
calc_grid_dim<T, IdxT, 11, block_dim, WRITE_TOPK_VALUES>(batch_size, len, sm_cnt);
17101709

1711-
if(grid_dim == 1)
1710+
if(1) // always take the one-block path: faster
17121711
{
17131712
standalone_stable_radix_topk_one_block_<T, IdxT, 11, block_dim, WRITE_TOPK_VALUES>(
17141713
buf,
@@ -2307,7 +2306,7 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
23072306
unsigned grid_dim =
23082307
aiter::calc_grid_dim<T, IdxT, 11, block_dim, false>(numRows, stride0, sm_cnt);
23092308

2310-
if(grid_dim == 1)
2309+
if(1)
23112310
{
23122311
aiter::standalone_stable_radix_topk_one_block_<T, IdxT, 11, block_dim, false>(
23132312
workspace,

0 commit comments

Comments
 (0)