Skip to content

Commit 1238f28

Browse files
charan-003Sai Charanbernhardmgruber
authored
add parallel scan support for TBB and OMP (#6178)
Co-authored-by: Sai Charan <[email protected]> Co-authored-by: Bernhard Manfred Gruber <[email protected]>
1 parent c9c9a0d commit 1238f28

File tree

4 files changed

+228
-4
lines changed

4 files changed

+228
-4
lines changed

thrust/thrust/system/omp/detail/scan.h

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,48 @@
2626
# pragma system_header
2727
#endif // no system header
2828

29-
// this system inherits scan
30-
#include <thrust/system/cpp/detail/scan.h>
29+
// OMP parallel scan implementation
30+
#include <thrust/system/omp/detail/execution_policy.h>
31+
32+
THRUST_NAMESPACE_BEGIN
33+
namespace system::omp::detail
34+
{
35+
36+
template <typename DerivedPolicy, typename InputIterator, typename OutputIterator, typename BinaryFunction>
37+
OutputIterator inclusive_scan(
38+
execution_policy<DerivedPolicy>& exec,
39+
InputIterator first,
40+
InputIterator last,
41+
OutputIterator result,
42+
BinaryFunction binary_op);
43+
44+
template <typename DerivedPolicy,
45+
typename InputIterator,
46+
typename OutputIterator,
47+
typename InitialValueType,
48+
typename BinaryFunction>
49+
OutputIterator inclusive_scan(
50+
execution_policy<DerivedPolicy>& exec,
51+
InputIterator first,
52+
InputIterator last,
53+
OutputIterator result,
54+
InitialValueType init,
55+
BinaryFunction binary_op);
56+
57+
template <typename DerivedPolicy,
58+
typename InputIterator,
59+
typename OutputIterator,
60+
typename InitialValueType,
61+
typename BinaryFunction>
62+
OutputIterator exclusive_scan(
63+
execution_policy<DerivedPolicy>& exec,
64+
InputIterator first,
65+
InputIterator last,
66+
OutputIterator result,
67+
InitialValueType init,
68+
BinaryFunction binary_op);
69+
70+
} // namespace system::omp::detail
71+
THRUST_NAMESPACE_END
72+
73+
#include <thrust/system/omp/detail/scan.inl>
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#pragma once
5+
6+
#include <thrust/detail/config.h>
7+
8+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
9+
# pragma GCC system_header
10+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
11+
# pragma clang system_header
12+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
13+
# pragma system_header
14+
#endif // no system header
15+
16+
#include <thrust/advance.h>
17+
#include <thrust/detail/function.h>
18+
#include <thrust/detail/temporary_array.h>
19+
#include <thrust/distance.h>
20+
#include <thrust/iterator/iterator_traits.h>
21+
#include <thrust/system/omp/detail/pragma_omp.h>
22+
#include <thrust/system/omp/detail/scan.h>
23+
24+
#include <cuda/std/__functional/invoke.h>
25+
#include <cuda/std/__numeric/exclusive_scan.h>
26+
#include <cuda/std/__numeric/inclusive_scan.h>
27+
#include <cuda/std/__numeric/reduce.h>
28+
#include <cuda/std/cmath>
29+
30+
#include <omp.h>
31+
32+
THRUST_NAMESPACE_BEGIN
33+
namespace system::omp::detail
34+
{
35+
36+
// Threshold below which serial scan is faster than parallel
37+
// Benchmarking shows parallel overhead dominates for small arrays
38+
inline constexpr size_t parallel_scan_threshold = 1024;
39+
40+
template <bool IsInclusive,
41+
typename DerivedPolicy,
42+
typename InputIterator,
43+
typename OutputIterator,
44+
typename InitialValueType,
45+
typename BinaryFunction>
46+
OutputIterator scan_impl(
47+
execution_policy<DerivedPolicy>& exec,
48+
InputIterator first,
49+
InputIterator last,
50+
OutputIterator result,
51+
InitialValueType init,
52+
BinaryFunction binary_op)
53+
{
54+
using namespace thrust::detail;
55+
56+
using ValueType =
57+
typename ::cuda::std::__accumulator_t<BinaryFunction, thrust::detail::it_value_t<InputIterator>, InitialValueType>;
58+
using Size = thrust::detail::it_difference_t<InputIterator>;
59+
60+
const Size n = ::cuda::std::distance(first, last);
61+
62+
if (n == 0)
63+
{
64+
return result;
65+
}
66+
67+
const int num_threads = omp_get_max_threads();
68+
69+
// Use serial scan for small arrays where parallel overhead dominates
70+
if (static_cast<size_t>(n) < parallel_scan_threshold || num_threads <= 1)
71+
{
72+
if constexpr (IsInclusive)
73+
{
74+
::cuda::std::inclusive_scan(first, last, result, binary_op, init);
75+
}
76+
else
77+
{
78+
::cuda::std::exclusive_scan(first, last, result, init, binary_op);
79+
}
80+
return result;
81+
}
82+
83+
thrust::detail::temporary_array<ValueType, DerivedPolicy> block_sums(exec, num_threads);
84+
85+
// Step 1: Reduce each block (N reads)
86+
THRUST_PRAGMA_OMP(parallel num_threads(num_threads))
87+
{
88+
const int tid = omp_get_thread_num();
89+
const Size block_size = ::cuda::ceil_div(n, num_threads);
90+
const Size start = tid * block_size;
91+
const Size end = ::cuda::std::min(start + block_size, n);
92+
93+
if (start < n)
94+
{
95+
block_sums[tid] = ::cuda::std::reduce(first + start, first + end, tid == 0 ? init : ValueType{}, binary_op);
96+
}
97+
}
98+
99+
// Step 2: Scan block sums using cuda::std::exclusive_scan
100+
::cuda::std::exclusive_scan(block_sums.begin(), block_sums.end(), block_sums.begin(), ValueType{}, binary_op);
101+
102+
// Step 3: Scan each block with offset (N reads/writes)
103+
THRUST_PRAGMA_OMP(parallel num_threads(num_threads))
104+
{
105+
const int tid = omp_get_thread_num();
106+
const Size block_size = ::cuda::ceil_div(n, num_threads);
107+
const Size start = tid * block_size;
108+
const Size end = ::cuda::std::min(start + block_size, n);
109+
110+
if (start < n)
111+
{
112+
const ValueType offset = block_sums[tid];
113+
if constexpr (IsInclusive)
114+
{
115+
::cuda::std::inclusive_scan(first + start, first + end, result + start, binary_op, offset);
116+
}
117+
else
118+
{
119+
::cuda::std::exclusive_scan(first + start, first + end, result + start, offset, binary_op);
120+
}
121+
}
122+
}
123+
124+
return result + n;
125+
}
126+
127+
template <typename DerivedPolicy, typename InputIterator, typename OutputIterator, typename BinaryFunction>
128+
OutputIterator inclusive_scan(
129+
execution_policy<DerivedPolicy>& exec,
130+
InputIterator first,
131+
InputIterator last,
132+
OutputIterator result,
133+
BinaryFunction binary_op)
134+
{
135+
using ValueType = thrust::detail::it_value_t<InputIterator>;
136+
return inclusive_scan(exec, first, last, result, ValueType{}, binary_op);
137+
}
138+
139+
template <typename DerivedPolicy,
140+
typename InputIterator,
141+
typename OutputIterator,
142+
typename InitialValueType,
143+
typename BinaryFunction>
144+
OutputIterator inclusive_scan(
145+
execution_policy<DerivedPolicy>& exec,
146+
InputIterator first,
147+
InputIterator last,
148+
OutputIterator result,
149+
InitialValueType init,
150+
BinaryFunction binary_op)
151+
{
152+
return scan_impl<true>(exec, first, last, result, init, binary_op);
153+
}
154+
155+
template <typename DerivedPolicy,
156+
typename InputIterator,
157+
typename OutputIterator,
158+
typename InitialValueType,
159+
typename BinaryFunction>
160+
OutputIterator exclusive_scan(
161+
execution_policy<DerivedPolicy>& exec,
162+
InputIterator first,
163+
InputIterator last,
164+
OutputIterator result,
165+
InitialValueType init,
166+
BinaryFunction binary_op)
167+
{
168+
return scan_impl<false>(exec, first, last, result, init, binary_op);
169+
}
170+
171+
} // namespace system::omp::detail
172+
THRUST_NAMESPACE_END

thrust/thrust/system/omp/detail/scan_by_key.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,9 @@
2828

2929
// this system inherits this algorithm
3030
#include <thrust/system/cpp/detail/scan_by_key.h>
31+
32+
// Ensure OMP scan is available before using generic scan_by_key
33+
#include <thrust/system/omp/detail/scan.h>
34+
35+
// use generic parallel implementation
36+
#include <thrust/system/detail/generic/scan_by_key.h>

thrust/thrust/system/tbb/detail/scan_by_key.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,8 @@
2626
# pragma system_header
2727
#endif // no system header
2828

29-
// this system inherits scan_by_key
30-
#include <thrust/system/cpp/detail/scan_by_key.h>
29+
// Ensure TBB scan is available before using generic scan_by_key
30+
#include <thrust/system/tbb/detail/scan.h>
31+
32+
// use generic parallel implementation
33+
#include <thrust/system/detail/generic/scan_by_key.h>

0 commit comments

Comments
 (0)