Skip to content

Commit 15d7233

Browse files
committed
SWDEV-567545 - Implement block_rank in co-op grid groups
1 parent 7be09f1 commit 15d7233

File tree

6 files changed

+47
-6
lines changed

6 files changed

+47
-6
lines changed

projects/clr/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ Full documentation for HIP is available at [rocm.docs.amd.com](https://rocm.docs
99
* New HIP APIs
1010
- `hipKernelGetParamInfo` returns the offset and size of a kernel parameter
1111

12+
* New HIP supports
13+
- `grid_group::block_rank()` returns the rank of the block in the calling thread
14+
1215
## HIP 7.2 for ROCm 7.2
1316

1417
### Added

projects/clr/hipamd/include/hip/amd_detail/amd_hip_cooperative_groups.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class thread_group {
9696
__CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
9797
//! Rank of the calling thread within [0, \link num_threads() num_threads() \endlink).
9898
__CG_QUALIFIER__ __hip_uint32_t thread_rank() const;
99-
//! Rank of the calling block within [0, \link num_threads() num_threads() \endlink).
99+
//! Rank of the block in calling thread within [0, \link num_threads() num_threads() \endlink).
100100
__CG_QUALIFIER__ __hip_uint32_t block_rank() const;
101101
//! Returns true if the group has not violated any API constraints.
102102
__CG_QUALIFIER__ bool is_valid() const;
@@ -159,10 +159,6 @@ class multi_grid_group : public thread_group {
159159
__CG_QUALIFIER__ __hip_uint32_t thread_rank() const {
160160
return internal::multi_grid::thread_rank();
161161
}
162-
//! @copydoc thread_group::block_rank
163-
__CG_QUALIFIER__ __hip_uint32_t block_rank() {
164-
return internal::workgroup::block_rank();
165-
}
166162
//! @copydoc thread_group::is_valid
167163
__CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
168164
//! @copydoc thread_group::sync
@@ -365,7 +361,6 @@ class tiled_group : public thread_group {
365361
__CG_QUALIFIER__ unsigned int thread_rank() const {
366362
return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.num_threads - 1));
367363
}
368-
369364
//! @copydoc thread_group::sync
370365
__CG_QUALIFIER__ void sync() const { internal::tiled_group::sync(); }
371366
};

projects/clr/hipamd/include/hip/amd_detail/hip_cooperative_groups_helper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ __CG_STATIC_QUALIFIER__ __hip_uint32_t thread_rank() {
182182
return (num_threads_till_current_workgroup + local_thread_rank);
183183
}
184184

185+
__CG_STATIC_QUALIFIER__ __hip_uint32_t block_rank() {
186+
return static_cast<__hip_uint32_t>((blockIdx.z * gridDim.y * gridDim.x) +
187+
(blockIdx.y * gridDim.x) + (blockIdx.x));
188+
}
189+
185190
__CG_STATIC_QUALIFIER__ bool is_valid() { return static_cast<bool>(__ockl_grid_is_valid()); }
186191

187192
__CG_STATIC_QUALIFIER__ void sync() { __ockl_grid_sync(); }

projects/hip-tests/catch/include/cpu_grid.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ struct CPUGrid {
4343
return thread_rank_in_grid % threads_in_block_count_;
4444
}
4545

46+
inline std::optional<unsigned int> block_rank_in_grid(
47+
const unsigned int thread_rank_in_grid) const {
48+
if (thread_rank_in_grid > thread_count_) {
49+
return std::nullopt;
50+
}
51+
52+
return thread_rank_in_grid / threads_in_block_count_;
53+
}
54+
4655
inline std::optional<dim3> block_idx(const unsigned int thread_rank_in_grid) const {
4756
if (thread_rank_in_grid > thread_count_) {
4857
return std::nullopt;

projects/hip-tests/catch/unit/cooperativeGrps/grid_group.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ static __global__ void grid_group_thread_rank_getter(unsigned int* thread_ranks)
3939
thread_ranks[thread_rank_in_grid()] = cg::this_grid().thread_rank();
4040
}
4141

42+
static __global__ void grid_group_block_rank_getter(unsigned int* block_ranks) {
43+
block_ranks[thread_rank_in_grid()] = cg::this_grid().block_rank();
44+
}
45+
4246
static __global__ void grid_group_is_valid_getter(unsigned int* is_valid_flags) {
4347
is_valid_flags[thread_rank_in_grid()] = cg::this_grid().is_valid();
4448
}
@@ -160,9 +164,18 @@ TEST_CASE("Unit_Grid_Group_Getters_Positive_Basic") {
160164
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
161165
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
162166
HIP_CHECK(hipDeviceSynchronize());
167+
HIP_CHECK(hipLaunchCooperativeKernel(grid_group_block_rank_getter, blocks, threads, params, 0, 0));
163168

164169
// Verify grid_group.is_valid() values
165170
ArrayAllOf(uint_arr.ptr(), grid.thread_count_, [](uint32_t) { return 1; });
171+
172+
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
173+
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
174+
HIP_CHECK(hipDeviceSynchronize());
175+
176+
// Verify grid_group.block_rank() values
177+
ArrayAllOf(uint_arr.ptr(), grid.thread_count_, [threads](uint32_t i) {
178+
return i/(threads.x * threads.y * threads.z); });
166179
}
167180

168181
/**

projects/hip-tests/catch/unit/cooperativeGrps/thread_block.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ static __global__ void thread_block_thread_rank_getter(unsigned int* thread_rank
4848
thread_ranks[thread_rank_in_grid()] = group.thread_rank();
4949
}
5050

51+
template <typename BaseType = cg::thread_block>
52+
static __global__ void thread_block_block_rank_getter(unsigned int* block_ranks) {
53+
const BaseType group = cg::this_thread_block();
54+
block_ranks[thread_rank_in_grid()] = group.block_rank();
55+
}
56+
5157
static __global__ void thread_block_group_indices_getter(dim3* group_indices) {
5258
group_indices[thread_rank_in_grid()] = cg::this_thread_block().group_index();
5359
}
@@ -110,10 +116,20 @@ TEST_CASE("Unit_Thread_Block_Getters_Positive_Basic") {
110116
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
111117
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
112118
HIP_CHECK(hipDeviceSynchronize());
119+
thread_block_block_rank_getter<<<blocks, threads>>>(uint_arr_dev.ptr());
120+
HIP_CHECK(hipGetLastError());
113121

114122
// Validate thread_block.thread_rank() values
115123
ArrayAllOf(uint_arr.ptr(), grid.thread_count_,
116124
[&grid](uint32_t i) { return grid.thread_rank_in_block(i).value(); });
125+
126+
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
127+
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
128+
HIP_CHECK(hipDeviceSynchronize());
129+
130+
// Validate thread_block.block_rank() values
131+
ArrayAllOf(uint_arr.ptr(), grid.thread_count_,
132+
[&grid](uint32_t i) { return grid.block_rank_in_grid(i).value(); });
117133
}
118134

119135
{

0 commit comments

Comments
 (0)