Skip to content

Commit 4334c15

Browse files
committed
More fixes to iterators defined in tests
1 parent 506b362 commit 4334c15

File tree

3 files changed

+40
-26
lines changed

3 files changed

+40
-26
lines changed

c/parallel/test/test_merge_sort.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,16 @@ C2H_TEST("DeviceMergeSort::SortKeys works with output iterators", "[merge_sort]"
432432
make_iterator<TestType, random_access_iterator_state_t>(
433433
{"random_access_iterator_state_t", "struct random_access_iterator_state_t { int* d_input; };\n"},
434434
{"advance",
435-
"extern \"C\" __device__ void advance(random_access_iterator_state_t* state, unsigned long long offset) {\n"
436-
" state->d_input += offset;\n"
435+
"extern \"C\" __device__ void advance(void* state, const void* offset) {\n"
436+
" auto* typed_state = static_cast<random_access_iterator_state_t*>(state);\n"
437+
" auto offset_val = *static_cast<const unsigned long long*>(offset);\n"
438+
" typed_state->d_input += offset_val;\n"
437439
"}"},
438440
{"dereference",
439-
"extern \"C\" __device__ void dereference(random_access_iterator_state_t* state, int x) {\n"
440-
" *state->d_input = x;\n"
441+
"extern \"C\" __device__ void dereference(void* state, const void* x) {\n"
442+
" auto* typed_state = static_cast<random_access_iterator_state_t*>(state);\n"
443+
" auto x_val = *static_cast<const int*>(x);\n"
444+
" *typed_state->d_input = x_val;\n"
441445
"}"});
442446
std::vector<TestType> input_keys = make_shuffled_key_ranks_vector<TestType>(num_items);
443447
std::vector<TestType> expected_keys = input_keys;
@@ -471,13 +475,16 @@ C2H_TEST("DeviceMergeSort::SortPairs works with output iterators for items", "[m
471475
make_iterator<TestType, item_random_access_iterator_state_t>(
472476
"struct item_random_access_iterator_state_t { int* d_input; };\n",
473477
{"advance",
474-
"extern \"C\" __device__ void advance(item_random_access_iterator_state_t* state, unsigned long long offset) "
475-
"{\n"
476-
" state->d_input += offset;\n"
478+
"extern \"C\" __device__ void advance(void* state, const void* offset) {\n"
479+
" auto* typed_state = static_cast<item_random_access_iterator_state_t*>(state);\n"
480+
" auto offset_val = *static_cast<const unsigned long long*>(offset);\n"
481+
" typed_state->d_input += offset_val;\n"
477482
"}"},
478483
{"dereference",
479-
"extern \"C\" __device__ void dereference(item_random_access_iterator_state_t* state, int x) {\n"
480-
" *state->d_input = x;\n"
484+
"extern \"C\" __device__ void dereference(void* state, const void* x) {\n"
485+
" auto* typed_state = static_cast<item_random_access_iterator_state_t*>(state);\n"
486+
" auto x_val = *static_cast<const int*>(x);\n"
487+
" *typed_state->d_input = x_val;\n"
481488
"}"});
482489

483490
pointer_t<TestType> input_keys_it(input_keys);

c/parallel/test/test_segmented_reduce.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -443,20 +443,23 @@ struct {0} {{
443443
/* 2 */ index_type_name);
444444

445445
static constexpr std::string_view it_advance_fn_def_src_tmpl = R"XXX(
446-
extern "C" __device__ void {0}({1}* state, {2} offset)
446+
extern "C" __device__ void {0}(void* state, const void* offset)
447447
{{
448-
state->linear_id += offset;
448+
auto* typed_state = static_cast<{1}*>(state);
449+
auto offset_val = *static_cast<const {2}*>(offset);
450+
typed_state->linear_id += offset_val;
449451
}}
450452
)XXX";
451453

452454
const std::string it_advance_fn_def_src =
453455
std::format(it_advance_fn_def_src_tmpl, /*0*/ advance_fn_name, state_name, index_type_name);
454456

455457
static constexpr std::string_view it_dereference_fn_src_tmpl = R"XXX(
456-
extern "C" __device__ void {0}({2} *state, {1}* result) {{
457-
unsigned long long col_id = (state->linear_id) / (state->n_rows);
458-
unsigned long long row_id = (state->linear_id) - col_id * (state->n_rows);
459-
*result = *(state->ptr + row_id * (state->n_cols) + col_id);
458+
extern "C" __device__ void {0}(const void* state, {1}* result) {{
459+
auto* typed_state = static_cast<const {2}*>(state);
460+
unsigned long long col_id = (typed_state->linear_id) / (typed_state->n_rows);
461+
unsigned long long row_id = (typed_state->linear_id) - col_id * (typed_state->n_rows);
462+
*result = *(typed_state->ptr + row_id * (typed_state->n_cols) + col_id);
460463
}}
461464
)XXX";
462465

c/parallel/test/test_util.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -955,8 +955,8 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
955955
else
956956
{
957957
dereference_fn_def_src = std::format(
958-
"extern \"C\" __device__ void {0}(const void* state, const void* x) {{\n"
959-
" auto* typed_state = static_cast<const {1}*>(state);\n"
958+
"extern \"C\" __device__ void {0}(void* state, const void* x) {{\n"
959+
" auto* typed_state = static_cast<{1}*>(state);\n"
960960
" auto x_val = *static_cast<const {2}*>(x);\n"
961961
" *typed_state->data = x_val{3};\n"
962962
"}}",
@@ -1005,8 +1005,9 @@ inline std::tuple<std::string, std::string, std::string> make_counting_iterator_
10051005
iterator_state_name);
10061006

10071007
std::string dereference_fn_def_src = std::format(
1008-
"extern \"C\" __device__ void {0}({1}* state, {2}* result) {{ \n"
1009-
" *result = state->value;\n"
1008+
"extern \"C\" __device__ void {0}(const void* state, {2}* result) {{ \n"
1009+
" auto* typed_state = static_cast<const {1}*>(state);\n"
1010+
" *result = typed_state->value;\n"
10101011
"}}",
10111012
dereference_fn_name,
10121013
iterator_state_name,
@@ -1093,8 +1094,9 @@ inline std::tuple<std::string, std::string, std::string> make_reverse_iterator_s
10931094
if (kind == iterator_kind::INPUT)
10941095
{
10951096
dereference_fn_src = std::format(
1096-
"extern \"C\" __device__ void {0}({1}* state, {2}* result) {{\n"
1097-
" *result = (*state->data){3};\n"
1097+
"extern \"C\" __device__ void {0}(const void* state, {2}* result) {{\n"
1098+
" auto* typed_state = static_cast<const {1}*>(state);\n"
1099+
" *result = (*typed_state->data){3};\n"
10981100
"}}",
10991101
dereference_fn_name,
11001102
iterator_state_name,
@@ -1104,8 +1106,10 @@ inline std::tuple<std::string, std::string, std::string> make_reverse_iterator_s
11041106
else
11051107
{
11061108
dereference_fn_src = std::format(
1107-
"extern \"C\" __device__ void {0}({1}* state, {2} x) {{\n"
1108-
" *state->data = x{3};\n"
1109+
"extern \"C\" __device__ void {0}(void* state, const void* x) {{\n"
1110+
" auto* typed_state = static_cast<{1}*>(state);\n"
1111+
" auto x_val = *static_cast<const {2}*>(x);\n"
1112+
" *typed_state->data = x_val{3};\n"
11091113
"}}",
11101114
dereference_fn_name,
11111115
iterator_state_name,
@@ -1261,7 +1265,7 @@ extern "C" __device__ void {0}(const void* transform_it_state, {2}* result) {{
12611265
{7} base_result;
12621266
{4}(&(typed_state->base_it_state), &base_result);
12631267
*result = {3}(
1264-
&(typed_state->functor_state),
1268+
const_cast<decltype(typed_state->functor_state)*>(&(typed_state->functor_state)),
12651269
base_result
12661270
);
12671271
}}
@@ -1434,7 +1438,7 @@ inline std::tuple<std::string, std::string, std::string> make_discard_iterator_s
14341438
if (kind == iterator_kind::INPUT)
14351439
{
14361440
dereference_fn_def_src = std::format(
1437-
"extern \"C\" __device__ void {0}({1}* /*state*/, {2}* /*result*/) {{\n"
1441+
"extern \"C\" __device__ void {0}(const void* /*state*/, {2}* /*result*/) {{\n"
14381442
"}}",
14391443
dereference_fn_name,
14401444
iterator_state_name,
@@ -1443,7 +1447,7 @@ inline std::tuple<std::string, std::string, std::string> make_discard_iterator_s
14431447
else
14441448
{
14451449
dereference_fn_def_src = std::format(
1446-
"extern \"C\" __device__ void {0}({1}* /*state*/, {2} /*x*/) {{\n"
1450+
"extern \"C\" __device__ void {0}(void* /*state*/, const void* /*x*/) {{\n"
14471451
"}}",
14481452
dereference_fn_name,
14491453
iterator_state_name,

0 commit comments

Comments
 (0)