@@ -244,16 +244,15 @@ struct item_pair
244244struct DeviceMergeSort_SortPairsCopy_CustomType_Fixture_Tag ;
245245C2H_TEST (" DeviceMergeSort:SortPairsCopy works with custom types" , " [merge_sort]" )
246246{
247- const size_t num_items = GENERATE_COPY (take (2 , random (1 , 100000 )), values ({5 , 10000 , 100000 }));
248- operation_t op = make_operation (
249- " op" ,
250- " struct key_pair { short a; size_t b; };\n "
251- " extern \" C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n "
252- " key_pair* lhs = static_cast<key_pair*>(lhs_ptr);\n "
253- " key_pair* rhs = static_cast<key_pair*>(rhs_ptr);\n "
254- " bool* out = static_cast<bool*>(out_ptr);\n "
255- " *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;\n "
256- " }" );
247+ const size_t num_items = GENERATE_COPY (take (2 , random (1 , 100000 )), values ({5 , 10000 , 100000 }));
248+ operation_t op = make_operation (" op" ,
249+ R"( struct key_pair { short a; size_t b; };
250+ extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {
251+ key_pair* lhs = static_cast<key_pair*>(lhs_ptr);
252+ key_pair* rhs = static_cast<key_pair*>(rhs_ptr);
253+ bool* out = static_cast<bool*>(out_ptr);
254+ *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;
255+ })" );
257256 const std::vector<short > a = generate<short >(num_items);
258257 const std::vector<size_t > b = generate<size_t >(num_items);
259258 std::vector<key_pair> input_keys (num_items);
@@ -301,16 +300,15 @@ C2H_TEST("DeviceMergeSort:SortPairsCopy works with custom types", "[merge_sort]"
301300struct DeviceMergeSort_SortPairsCopy_CustomType_WellKnown_Fixture_Tag ;
302301C2H_TEST (" DeviceMergeSort:SortPairsCopy works with custom types with well-known predicates" , " [merge_sort][well_known]" )
303302{
304- const size_t num_items = GENERATE_COPY (take (2 , random (1 , 100000 )), values ({5 , 10000 , 100000 }));
305- operation_t op_state = make_operation (
306- " op" ,
307- " struct key_pair { short a; size_t b; };\n "
308- " extern \" C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n "
309- " key_pair* lhs = static_cast<key_pair*>(lhs_ptr);\n "
310- " key_pair* rhs = static_cast<key_pair*>(rhs_ptr);\n "
311- " bool* out = static_cast<bool*>(out_ptr);\n "
312- " *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;\n "
313- " }" );
303+ const size_t num_items = GENERATE_COPY (take (2 , random (1 , 100000 )), values ({5 , 10000 , 100000 }));
304+ operation_t op_state = make_operation (" op" ,
305+ R"( struct key_pair { short a; size_t b; };
306+ extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {
307+ key_pair* lhs = static_cast<key_pair*>(lhs_ptr);
308+ key_pair* rhs = static_cast<key_pair*>(rhs_ptr);
309+ bool* out = static_cast<bool*>(out_ptr);
310+ *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;
311+ })" );
314312 cccl_op_t op = op_state;
315313 op.type = cccl_op_kind_t ::CCCL_LESS;
316314 const std::vector<short > a = generate<short >(num_items);
@@ -432,17 +430,17 @@ C2H_TEST("DeviceMergeSort::SortKeys works with output iterators", "[merge_sort]"
432430 make_iterator<TestType, random_access_iterator_state_t >(
433431 {" random_access_iterator_state_t" , " struct random_access_iterator_state_t { int* d_input; };\n " },
434432 {" advance" ,
435- " extern \" C \ " __device__ void advance(void* state, const void* offset) {\n "
436- " auto* typed_state = static_cast<random_access_iterator_state_t*>(state);\n "
437- " auto offset_val = *static_cast<const unsigned long long*>(offset);\n "
438- " typed_state->d_input += offset_val;\n "
439- " } " },
433+ R"( extern "C " __device__ void advance(void* state, const void* offset) {
434+ auto* typed_state = static_cast<random_access_iterator_state_t*>(state);
435+ auto offset_val = *static_cast<const unsigned long long*>(offset);
436+ typed_state->d_input += offset_val;
437+ } ) " },
440438 {" dereference" ,
441- " extern \" C \ " __device__ void dereference(void* state, const void* x) {\n "
442- " auto* typed_state = static_cast<random_access_iterator_state_t*>(state);\n "
443- " auto x_val = *static_cast<const int*>(x);\n "
444- " *typed_state->d_input = x_val;\n "
445- " } " });
439+ R"( extern "C " __device__ void dereference(void* state, const void* x) {
440+ auto* typed_state = static_cast<random_access_iterator_state_t*>(state);
441+ auto x_val = *static_cast<const int*>(x);
442+ *typed_state->d_input = x_val;
443+ } ) " });
446444 std::vector<TestType> input_keys = make_shuffled_key_ranks_vector<TestType>(num_items);
447445 std::vector<TestType> expected_keys = input_keys;
448446
@@ -475,17 +473,17 @@ C2H_TEST("DeviceMergeSort::SortPairs works with output iterators for items", "[m
475473 make_iterator<TestType, item_random_access_iterator_state_t >(
476474 " struct item_random_access_iterator_state_t { int* d_input; };\n " ,
477475 {" advance" ,
478- " extern \" C \ " __device__ void advance(void* state, const void* offset) {\n "
479- " auto* typed_state = static_cast<item_random_access_iterator_state_t*>(state);\n "
480- " auto offset_val = *static_cast<const unsigned long long*>(offset);\n "
481- " typed_state->d_input += offset_val;\n "
482- " } " },
476+ R"( extern "C " __device__ void advance(void* state, const void* offset) {
477+ auto* typed_state = static_cast<item_random_access_iterator_state_t*>(state);
478+ auto offset_val = *static_cast<const unsigned long long*>(offset);
479+ typed_state->d_input += offset_val;
480+ } ) " },
483481 {" dereference" ,
484- " extern \" C \ " __device__ void dereference(void* state, const void* x) {\n "
485- " auto* typed_state = static_cast<item_random_access_iterator_state_t*>(state);\n "
486- " auto x_val = *static_cast<const int*>(x);\n "
487- " *typed_state->d_input = x_val;\n "
488- " } " });
482+ R"( extern "C " __device__ void dereference(void* state, const void* x) {
483+ auto* typed_state = static_cast<item_random_access_iterator_state_t*>(state);
484+ auto x_val = *static_cast<const int*>(x);
485+ *typed_state->d_input = x_val;
486+ } ) " });
489487
490488 pointer_t <TestType> input_keys_it (input_keys);
491489 pointer_t <item_t > input_items_it (input_items);
@@ -657,12 +655,12 @@ C2H_TEST("MergeSort works with C++ source operations using custom headers", "[me
657655/* C2H_TEST("DeviceMergeSort:SortPairsCopy fails to build for large types due to no vsmem", "[merge_sort]")
658656{
659657 const size_t num_items = 1;
660- operation_t op = make_operation(
658+ operation_t op = make_operation(
661659 "op",
662- " struct large_key_pair { int a; char c[100]; };\n"
663- " extern \"C\ " __device__ bool op(large_key_pair lhs, large_key_pair rhs) {\n"
664- " return lhs.a < rhs.a;\n"
665- "} ");
660+ R"( struct large_key_pair { int a; char c[100]; };
661+ extern "C " __device__ bool op(large_key_pair lhs, large_key_pair rhs) {
662+ return lhs.a < rhs.a;
663+ }) ");
666664 const std::vector<int> a = generate<int>(num_items);
667665 std::vector<large_key_pair> input_keys(num_items);
668666 for (std::size_t i = 0; i < num_items; ++i)
0 commit comments