1010#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
1111#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1212
13+ #include "nbl/builtin/hlsl/algorithm.hlsl"
1314#include "nbl/builtin/hlsl/functional.hlsl"
1415#include "nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl"
1516
@@ -138,45 +139,52 @@ SPECIALIZE_ALL(maximum,Max);
138139#undef SPECIALIZE_ALL
139140#undef SPECIALIZE
140141
// Step functor for the portability (non-native) subgroup inclusive scan.
// Holds this invocation's running scan value; meant to be driven by
// unrolled_for_range, which instantiates __call<StepLog2> once per
// doubling step (StepLog2 = 0..SubgroupSizeLog2-1) of a Hillis-Steele scan.
template<class BinOp>
struct inclusive_scan_impl
{
    using scalar_t = typename BinOp::type_t;

    // Factory: capture the invocation's initial value and cache its lane index
    // so it doesn't have to be re-read on every step.
    static inclusive_scan_impl<BinOp> create(scalar_t _value)
    {
        inclusive_scan_impl<BinOp> retval;
        retval.value = _value;
        retval.subgroupInvocation = glsl::gl_SubgroupInvocationID();
        return retval;
    }

    // One scan step: combine our value with the one held `1<<StepLog2` lanes below.
    template<uint16_t StepLog2>
    void __call()
    {
        BinOp op;
        const uint32_t step = 1u << StepLog2;
        // keep the subgroup in lockstep between steps so every lane shuffles
        // the value from the same iteration
        spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
        // all invocations must execute the shuffle, even those that discard the result
        scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
        // lanes without a neighbour `step` below them substitute the identity,
        // leaving their value unchanged by op()
        value = op(value, hlsl::mix(rhs, BinOp::identity, subgroupInvocation < step));
    }

    scalar_t value;
    uint32_t subgroupInvocation;
};
168+
141169// specialize portability
142170template<class Params, class BinOp>
143171struct inclusive_scan<Params, BinOp, 1 , false >
144172{
145173 using type_t = typename Params::type_t;
146174 using scalar_t = typename Params::scalar_t;
147175 using binop_t = typename Params::binop_t;
148- // assert T == scalar type, binop::type == T
149176 using config_t = typename Params::config_t;
150177
151- // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
152- // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
153-
154178 scalar_t operator ()(scalar_t value)
155179 {
156180 return __call (value);
157181 }
158182
159183 static scalar_t __call (scalar_t value)
160184 {
161- // sync up each subgroup invocation so it runs in lockstep
162- // not ideal because might not write to shared memory but a storage class is needed
163- spirv::memoryBarrier (spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
164-
165- binop_t op;
166- const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID ();
167-
168- scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
169- value = op (value, hlsl::mix (rhs, binop_t::identity, subgroupInvocation < 1u));
170-
171- const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
172- [unroll]
173- for (uint32_t i = 1 ; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
174- {
175- const uint32_t step = 1u << i;
176- rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
177- value = op (value, hlsl::mix (rhs, binop_t::identity, subgroupInvocation < step));
178- }
179- return value;
185+ inclusive_scan_impl<binop_t> f_impl = inclusive_scan_impl<binop_t>::create (value);
186+ unrolled_for_range<0 , config_t::SizeLog2>::template __call<inclusive_scan_impl<binop_t> >(f_impl);
187+ return f_impl.value;
180188 }
181189};
182190
@@ -190,14 +198,36 @@ struct exclusive_scan<Params, BinOp, 1, false>
    // Exclusive scan via shift-then-inclusive-scan: each lane takes the value
    // of the lane directly below it (identity for lane 0) and inclusive-scans that.
    scalar_t operator()(scalar_t value)
    {
        // sync up each subgroup invocation so it runs in lockstep
        spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);

        // bool(laneID) selects the shuffled neighbour for every lane except lane 0,
        // which has no lane below it and gets the identity instead
        scalar_t left = hlsl::mix(binop_t::identity, glsl::subgroupShuffleUp<scalar_t>(value,1), bool(glsl::gl_SubgroupInvocationID()));
        return inclusive_scan<Params, BinOp, 1, false>::__call(left);
    }
199206};
200207
// Step functor for the portability (non-native) subgroup reduction.
// Holds this invocation's running partial result; driven by unrolled_for_range,
// which instantiates __call<StepLog2> once per butterfly (XOR-shuffle) step.
// After all log2(SubgroupSize) steps every lane holds the full reduction.
template<class BinOp>
struct reduction_impl
{
    using scalar_t = typename BinOp::type_t;

    // Factory: capture the invocation's initial value.
    static reduction_impl<BinOp> create(scalar_t _value)
    {
        reduction_impl<BinOp> retval;
        retval.value = _value;
        return retval;
    }

    // One butterfly step: combine with the lane whose ID differs in bit StepLog2.
    template<uint16_t StepLog2>
    void __call()
    {
        BinOp op;
        // keep the subgroup in lockstep so all lanes shuffle values from the same step
        spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsMaskNone);
        value = op(glsl::subgroupShuffleXor<scalar_t>(value, 0x1u<<StepLog2),value);
    }

    scalar_t value;
};
230+
201231template<class Params, class BinOp>
202232struct reduction<Params, BinOp, 1 , false >
203233{
@@ -206,22 +236,11 @@ struct reduction<Params, BinOp, 1, false>
206236 using binop_t = typename Params::binop_t;
207237 using config_t = typename Params::config_t;
208238
209- // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
210- // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
211-
212239 scalar_t operator ()(scalar_t value)
213240 {
214- // sync up each subgroup invocation so it runs in lockstep
215- // not ideal because might not write to shared memory but a storage class is needed
216- spirv::memoryBarrier (spv::ScopeSubgroup, spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireMask);
217-
218- binop_t op;
219- const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
220- [unroll]
221- for (uint32_t i = 0 ; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
222- value = op (glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);
223-
224- return value;
241+ reduction_impl<binop_t> f_impl = reduction_impl<binop_t>::create (value);
242+ unrolled_for_range<0 , config_t::SizeLog2>::template __call<reduction_impl<binop_t> >(f_impl);
243+ return f_impl.value;
225244 }
226245};
227246
0 commit comments