diff --git a/Project.toml b/Project.toml index 2344fc3..f5746ab 100644 --- a/Project.toml +++ b/Project.toml @@ -34,11 +34,13 @@ Transducers = "0.4.35" julia = "1.6" [extensions] +FoldsMetalExt = "Metal" FoldsOnlineStatsBaseExt = "OnlineStatsBase" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -46,4 +48,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Aqua", "Documenter", "Test"] [weakdeps] +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" diff --git a/ext/FoldsMetalExt/FoldsMetalExt.jl b/ext/FoldsMetalExt/FoldsMetalExt.jl new file mode 100644 index 0000000..078df16 --- /dev/null +++ b/ext/FoldsMetalExt/FoldsMetalExt.jl @@ -0,0 +1,56 @@ +module FoldsMetalExt + +export MetalEx, CoalescedMetalEx + +using Metal +using Metal: @allowscalar, mtlconvert, mtlfunction +using Core: Typeof +using InitialValues: InitialValue, asmonoid +#using UnionArrays: UnionArrays, UnionVector +using Transducers: + @return_if_reduced, + Executor, + Map, + Reduced, + Transducer, + Transducers, + combine, + complete, + next, + opcompose, + reduced, + start, + transduce, + unreduced + +# TODO: Don't import internals from Transducers: +using Transducers: + AbstractReduction, + DefaultInit, + DefaultInitOf, + EmptyResultError, + IdentityTransducer, + Reduction, + _reducingfunction, + completebasecase, + extract_transducer, + foldl_basecase + +include("utils.jl") +include("kernels.jl") +include("unionvalues.jl") +include("shfl.jl") +include("api.jl") +include("introspection.jl") + +#= Use README as the docstring of the module: +@doc let path = joinpath(dirname(@__DIR__), "README.md") + include_dependency(path) + doc = read(path, String) + doc = replace(doc, r"^```julia"m => "```jldoctest README") + doc = replace(doc, "(https://juliafolds.github.io/FoldsMetalExt.jl/dev/examples/)" => "(@ref examples-toc)") + doc +end FoldsMetalExt +=# + +end diff --git a/ext/FoldsMetalExt/api.jl b/ext/FoldsMetalExt/api.jl new file mode 100644 index 0000000..e12da34 --- /dev/null +++ b/ext/FoldsMetalExt/api.jl @@ -0,0 +1,42 @@ +""" + MetalEx() + +A fold executor implemented using Metal.jl. + +For more information about executor, see +[Transducers.jl's glossary section](https://juliafolds.github.io/Transducers.jl/dev/explanation/glossary/#glossary-executor) +and +[FLoops.jl's API section](https://juliafolds.github.io/FLoops.jl/dev/reference/api/#executor). + +# Examples +```jldoctest +julia> using FoldsMetalExt, Folds + +julia> Folds.sum(1:10, MetalEx()) +55 +``` +""" +struct MetalEx{K} <: Executor + kwargs::K +end + +popsimd(; simd = nothing, kwargs...) = kwargs + +Transducers.transduce(xf, rf::RF, init, xs, exc::MetalEx) where {RF} = + _transduce_metal(xf, rf, init, xs; popsimd(; exc.kwargs...)...) + +Transducers.executor_type(::MtlArray) = MetalEx + +""" + CoalescedMetalEx() + +A fold executor implemented using Metal.jl. It uses coalesced memory access +while supporting non-commutative loops. It can be faster but more limited than +`MetalEx`. +""" +struct CoalescedMetalEx{K} <: Executor + kwargs::K +end + +Transducers.transduce(xf, rf::RF, init, xs, exc::CoalescedMetalEx) where {RF} = + transduce_shfl(xf, rf, init, xs; popsimd(; exc.kwargs...)...) 
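+
+# Rough usage sketch for `CoalescedMetalEx` (not a doctest; it mirrors the
+# `MetalEx` example above and assumes a Metal-capable GPU with Folds.jl loaded
+# alongside this extension):
+#
+#     julia> using Folds
+#
+#     julia> Folds.sum(1:10, CoalescedMetalEx())
+#     55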
diff --git a/ext/FoldsMetalExt/introspection.jl b/ext/FoldsMetalExt/introspection.jl
new file mode 100644
index 0000000..38f999b
--- /dev/null
+++ b/ext/FoldsMetalExt/introspection.jl
@@ -0,0 +1,33 @@
+const RUN_ON_HOST_IF_NORETURN = Ref(false)
+
+struct FailedInference <: Exception
+    f::Any
+    kernel_args::Tuple
+    kernel_return::Any
+    host_args::Tuple
+    host_return::Any
+end
+
+function Base.showerror(io::IO, err::FailedInference)
+    print(io, FailedInference, ": ")
+    if err.kernel_return === Union{}
+        print(io, "Kernel is inferred to not return (return type is `Union{}`)")
+    else
+        print(io, "Kernel is inferred to return invalid type: ", err.kernel_return)
+    end
+    if err.kernel_return === Union{} && err.host_return !== Union{}
+        println(io)
+        print(io, "Note: on the host, the return type is inferred as ", err.host_return)
+    end
+    println(io)
+    printstyled(io, "HINT"; bold = true, color = :light_black)
+    printstyled(
+        io,
+        ": if this exception is caught as `err`, use `Metal.code_typed(err)` to",
+        " introspect the erroneous code.";
+        color = :light_black,
+    )
+end
+
+Metal.code_typed(err::FailedInference; options...) =
+    Metal.code_typed(err.f, typeof(err.kernel_args); options...)
diff --git a/ext/FoldsMetalExt/kernels.jl b/ext/FoldsMetalExt/kernels.jl
new file mode 100644
index 0000000..8bd422d
--- /dev/null
+++ b/ext/FoldsMetalExt/kernels.jl
@@ -0,0 +1,369 @@
+_transduce_metal(xf::Transducer, op, init, xs; kwargs...) =
+    _transduce_metal(xf'(op), init, xs; kwargs...)
+
+function _transduce_metal(op, init, xs;)
+    xf0, coll = extract_transducer(xs)
+    # TODO: more systematic approach to this (and also support product)
+    if coll isa Iterators.Zip
+        arrays = coll.is
+        xf = xf0
+    elseif coll isa Iterators.Pairs
+        arrays = (keys(coll), values(coll))
+        xf = xf0
+    else
+        arrays = (coll,)
+        xf = opcompose(Map(first), xf0)
+    end
+    rf = _reducingfunction(xf, op; init = init)
+    acc = transduce_impl(rf, init, arrays...)
+    rf_dev = mtlconvert(rf)
+    if rf_dev === rf
+        result = complete(rf, acc)
+    else
+        result = complete_on_device(rf_dev, acc)
+    end
+    if unreduced(result) isa DefaultInitOf
+        throw(EmptyResultError(rf))
+    end
+    return result
+end
+
+function transduce_impl(rf::F, init, arrays...) where {F}
+    ys, = (dest, buf) = _transduce!(nothing, rf, init, arrays...)
+    if buf === nothing
+        # The accumulator is a singleton. Once we are finished with the
+        # side-effects of the basecase, transduce is done:
+        return ys
+    end
+    # @info "ys, = _transduce!(nothing, rf, ...)" Text(summary(ys))
+    # @info "ys, = _transduce!(nothing, rf, ...)" collect(ys)
+    length(ys) == 1 && return @allowscalar ys[1]
+    rf2 = AlwaysCombine(rf)
+    while true
+        ys, = _transduce!(buf, rf2, CombineInit(), ys)
+        # @info "ys, = _transduce!(buf, rf2, ...)" Text(summary(ys))
+        # @info "ys, = _transduce!(buf, rf2, ...)" collect(ys)
+        length(ys) == 1 && return @allowscalar ys[1]
+        dest, buf = buf, dest
+        # reusing buffer; is it useful?
+    end
+end
+
+const _TRUE_ = Ref(true)
+
+function fake_transduce(rf, xs, init, ::Val{IncludeInit} = Val(false)) where {IncludeInit}
+    if IncludeInit
+        if _TRUE_[]
+            return completebasecase(rf, start(rf, init))
+        end
+    end
+    if _TRUE_[]
+        acc1 = next(rf, start(rf, init), first(xs))
+        for x in xs
+            acc1 = next(rf, acc1, x)
+        end
+        return completebasecase(rf, acc1)
+    else
+        acc1 = fake_transduce(rf, xs, init)
+        acc2 = fake_transduce(rf, xs, init)
+        acc3 = completebasecase(rf, acc1)
+        acc4 = completebasecase(rf, acc2)
+        acc5 = _combine(rf, acc3, acc4)
+        return acc5
+    end
+end
+
+struct DisallowedElementTypeError{T} <: Exception end
+Base.showerror(io::IO, ::DisallowedElementTypeError{T}) where {T} =
+    print(io, "accumulator type must be `isbits` or `isbitsunion`; got: $T")
+
+function allocate_buffer(::Type{T}, n) where {T}
+    if isbitstype(T)
+        return MtlVector{T}(undef, n)
+    elseif Base.isbitsunion(T)
+        error("UnionArrays not supported")
+        return UnionVector(undef, MtlVector, T, n)
+    else
+        # TODO: Fall back to the mutate-or-widen approach? (e.g., run the first
+        # iteration on the CPU, and then use it as the initial guess of the
+        # accumulator?)
+        throw(DisallowedElementTypeError{T}())
+    end
+end
+
+Base.@propagate_inbounds getvalues(i) = ()
+Base.@propagate_inbounds getvalues(i, a) = (a[i],)
+Base.@propagate_inbounds getvalues(i, a, as...) = (a[i], getvalues(i, as...)...)
+
+function _infer_acctype(rf, init, arrays, include_init::Bool = false)
+    fake_args = (
+        mtlconvert(rf),
+        zip(map(mtlconvert, arrays)...),
+        mtlconvert(init),
+        Val(include_init),
+    )
+    fake_args_tt = Tuple{map(Typeof, fake_args)...}
+    acctype = Metal.return_type(fake_transduce, fake_args_tt)
+    if acctype === Union{} || !Base.datatype_pointerfree(Some{acctype})
+        host_args = (rf, zip(arrays...), init)
+        acctype_host = Core.Compiler.return_type(fake_transduce, Tuple{map(Typeof, host_args)...})
+        if RUN_ON_HOST_IF_NORETURN[] && acctype_host === Union{}
+            fake_transduce(host_args...)
+            error("unreachable: incorrect inference")
+        end
+        throw(FailedInference(fake_transduce, fake_args, acctype, host_args, acctype_host))
+    end
+    return acctype
+    # Note: the result of `return_type` is not observable by the caller of the
+    # API `transduce_impl`
+end
+
+function _transduce!(buf, rf::F, init, arrays...) where {F}
+    idx = eachindex(arrays...)
+    n = Int(length(idx)) # e.g., `length(UInt64(0):UInt64(1))` is not an `Int`
+
+    wanted_threads = nextpow(2, n)
+    compute_threads(max_threads) =
+        wanted_threads > max_threads ? prevpow(2, max_threads) : wanted_threads
+
+    acctype = if buf === nothing
+        _infer_acctype(rf, init, arrays)
+    else
+        eltype(buf)
+    end
+    if !isbitstype(acctype)
+        error("non-isbits element type not supported")
+    end
+    # @show acctype
+    buf0 = if Base.issingletontype(acctype)
+        nothing
+    elseif buf === nothing
+        # TODO: find a way to compute the type for `mtlfunction` without
+        # creating a dummy object.
+        allocate_buffer(acctype, 0)
+    else
+        buf
+    end
+    args = (buf0, Val(2), rf, init, 0, idx, arrays...)
+    # global _KARGS = args
+    kernel_tt = Tuple{map(x -> Typeof(mtlconvert(x)), args)...}
+    # global KERNEL_TT = kernel_tt
+    kernel = mtlfunction(transduce_kernel!, kernel_tt)
+    effelsize = if isbitstype(acctype)
+        sizeof(acctype)
+    else
+        error("UnionArrays not supported")
+        sizeof(UnionArrays.buffereltypefor(acctype)) + sizeof(UInt8)
+    end
+    # @show acctype UnionArrays.buffereltypefor(acctype) effelsize
+    compute_shmem(threads) = 2 * threads * effelsize
+    #kernel_config =
+    #    launch_configuration(kernel.fun; shmem = compute_shmem ∘ compute_threads)
+    #threads = compute_threads(kernel_config.threads)
+    #shmem = compute_shmem(threads)
+    #basesize = cld(n, kernel_config.blocks * threads)
+    threads = min(512, wanted_threads)
+    groups = cld(n, threads)
+    basesize = 1
+    #@assert blocks <= kernel_config.blocks
+
+    if Base.issingletontype(acctype)
+        @metal(
+            threads = threads,
+            groups = groups,
+            transduce_kernel!(nothing, Val(2*cld(threads, basesize)), rf, init, basesize, idx, arrays...)
+        )
+        return acctype.instance, nothing
+    end
+
+    if buf === nothing
+        dest_buf = allocate_buffer(acctype, groups + cld(groups, threads))
+        dest = view(dest_buf, 1:groups)
+        buf = view(dest_buf, groups+1:length(dest_buf))
+    else
+        dest = view(buf, 1:groups)
+    end
+    # @show threads groups basesize init rf idx size(dest)
+
+    # global INVOKE_KERNEL = function ()
+    #     @cuda(
+    #         threads = threads,
+    #         blocks = blocks,
+    #         shmem = shmem,
+    #         transduce_kernel!(dest, rf, init, basesize, idx, arrays...)
+    #     )
+    # end
+
+    @metal(
+        threads = threads,
+        groups = groups,
+        transduce_kernel!(dest, Val(2*cld(threads, basesize)), rf, init, basesize, idx, arrays...)
+    )
+
+    return dest, buf
+end
+
+# Since Metal already requires that everything is inlined, `restack` is not
+# useful. Instead, it's better to avoid introducing extra function calls to
+# reduce the chance that the inliner gives up.
+@static if isdefined(Transducers, :restack) && isdefined(Metal, Symbol("@device_override"))
+    Metal.@device_override Transducers.restack(x) = x
+end
+
+function transduce_kernel!(
+    dest::Union{AbstractArray,Nothing},
+    ::Val{SHMEM_SIZE},
+    rf::F,
+    init,
+    basesize,
+    idx,
+    arrays...,
+) where {F, SHMEM_SIZE}
+    tidx = thread_position_in_threadgroup_1d()
+    gidx = threadgroup_position_in_grid_1d()
+    isleader = tidx == 1 && gidx == 1
+    # Use the undef state of `acc` as an "extra Union"; i.e., treat as if the
+    # initial iteration is unrolled, even though it may not be possible to do so
+    # for all threads:
+    local acc
+    acc_isdefined = false
+    let n = length(idx),
+        offset = thread_position_in_grid_1d() - 1,
+        i1 = offset * basesize + 1,
+        x1, xf
+        if i1 <= n
+            x1 = @inbounds getvalues(idx[i1], arrays...)
+            @inline getinput(i) = @inbounds getvalues(idx[i], arrays...)
+            xf = Map(getinput)
+            acc = foldl_basecase(
+                Reduction(xf, rf),
+                next(rf, start(rf, init), x1),
+                offset*basesize+2:min((offset + 1) * basesize, n),
+            )
+            acc_isdefined = true
+        end
+    end
+
+    dest === nothing && return
+
+    # NOTE: Here, `acc` may have a different type for each thread. Since the
+    # following code contains `threadgroup_barrier()`, we cannot introduce any
+    # dispatch boundary ("function barrier") here. Otherwise, since dispatch is
+    # just a branch on the GPU, the resulting code tries to synchronize across
+    # different branches and hence deadlocks.
+ + # shared mem for a complete reduction + T = eltype(dest) + if isbitstype(T) + shared = MtlThreadGroupArray(T, SHMEM_SIZE) + else + isleader && @mtlprintln("[$gidx:$tidx] non-isbits element type not supported") + error("") + S = UnionArrays.buffereltypefor(T) + data = MtlThreadGroupArray(S, (2 * threadgroups_per_grid_1d(),)) + offset = sizeof(S) * 2 * threadgroups_per_grid_1d() + typeids = MtlThreadGroupArray(UInt8, (2 * threadgroups_per_grid_1d(),), offset) + @assert UInt(pointer(data, length(data) + 1)) == UInt(pointer(typeids)) + shared = UnionVector(T, data, typeids) + end + if acc_isdefined + # Manual union splitting (required for non-type-stable reduction like + # `Folds.sum(last, pairs(xs))`): + @inbounds shared[thread_position_in_threadgroup_1d()] = acc + #= + @manual_union_split( + isbitstype(T), + acc isa UnionArrays.eltypebyid(shared, Val(1)), + acc isa UnionArrays.eltypebyid(shared, Val(2)), + acc isa UnionArrays.eltypebyid(shared, Val(3)), + acc isa UnionArrays.eltypebyid(shared, Val(4)), + acc isa UnionArrays.eltypebyid(shared, Val(5)), + acc isa UnionArrays.eltypebyid(shared, Val(6)), + ) do + @inbounds shared[thread_position_in_threadgroup_1d()] = acc + end + =# + end + + # `iseven(m)` in the `while` loop below enforces that indexing on `shared` + # is in bounds. But, for the last block we need to make sure to combine + # accumulators only within the valid thread indices. + bound = let n = length(idx), + nbasecases = cld(n, basesize), + offsetb = (threadgroup_position_in_grid_1d() - 1) * threads_per_threadgroup_1d() + max(0, nbasecases - offsetb) + end + + m = thread_position_in_threadgroup_1d() - 1 + t = thread_position_in_threadgroup_1d() + s = 1 + c = threads_per_threadgroup_1d() >> 1 + #target = 3 + #tidx == target && @mtlprintln("[$gidx:$tidx] bound = $bound, m = $m, t = $t, s = $s, c = $c") + while c != 0 + threadgroup_barrier(Metal.MemoryFlagThreadGroup) + #tidx == target && @mtlprintln("[$gidx:$tidx] m = $m, t = $t, s = $s, c = $c") + if t + s <= bound && iseven(m) + @inbounds shared[t] = _combine(rf, shared[t], shared[t+s]) + #tidx == 3 && @mtlprintln("[$gidx:$tidx] shared[$t] = $(shared[t]) ($t <- $t + $(t+s))") + m >>= 1 + end + s <<= 1 + c >>= 1 + end + + if t == 1 + @inbounds dest[threadgroup_position_in_grid_1d()] = shared[1] + end + + return +end + +struct CombineInit end + +struct AlwaysCombine{I} <: AbstractReduction{I} + inner::I +end + +#= +AlwaysCombine(rf::Transducers.R_{Map}) = AlwaysCombine(Transducers.inner(rf)) +AlwaysCombine(rf::Transducers.BottomRF) = AlwaysCombine(Transducers.inner(rf)) +=# + +@inline Transducers.start(rf::AlwaysCombine, init) = start(rf.inner, init) +@inline Transducers.start(::AlwaysCombine, init::CombineInit) = init +@inline Transducers.next(::AlwaysCombine, ::CombineInit, input) = first(input) +@inline Transducers.next(rf::F, acc, input) where {F<:AlwaysCombine} = + _combine(rf.inner, acc, first(input)) +@inline Transducers.complete(rf::F, result) where {F<:AlwaysCombine} = + complete(rf.inner, result) +@inline Transducers.combine(rf::F, a, b) where {F<:AlwaysCombine} = _combine(rf.inner, a, b) + +# Semantically correct but inefficient (eager) handling of `Reduced`. +@inline _combine(rf, a::Reduced, b::Reduced) = a +@inline _combine(rf, a::Reduced, b) = a +@inline _combine(rf::RF, a, b::Reduced) where {RF} = reduced(combine(rf, a, unreduced(b))) +@inline _combine(rf::RF, a, b) where {RF} = combine(rf, a, b) + +# TODO: merge this into transduce_kernel! 
+function complete_kernel!(buf, rf, acc) + buf[1] = complete(rf, acc) + return +end + +function complete_kernel!(rf, acc) + complete(rf, acc) + return +end + +function complete_on_device(rf_dev::RF, acc::ACC) where {RF, ACC} + # global CARGS = (rf_dev, acc) + resulttype = Metal.return_type(complete, Tuple{RF,ACC}) + if Base.issingletontype(resulttype) + @metal complete_kernel!(rf_dev, acc) + return resulttype.instance + end + buf = allocate_buffer(resulttype, 1) + @metal complete_kernel!(buf, rf_dev, acc) + return @allowscalar buf[1] +end diff --git a/ext/FoldsMetalExt/shfl.jl b/ext/FoldsMetalExt/shfl.jl new file mode 100644 index 0000000..02f01e9 --- /dev/null +++ b/ext/FoldsMetalExt/shfl.jl @@ -0,0 +1,310 @@ +transduce_shfl(xf::Transducer, op, init, xs; kwargs...) = + transduce_shfl(xf'(op), init, xs; kwargs...) + +function transduce_shfl(op, init, xs;) + xf0, coll = extract_transducer(xs) + # TODO: more systematic approach to this (and also support product) + if coll isa Iterators.Zip + arrays = coll.is + xf = xf0 + elseif coll isa Iterators.Pairs + arrays = (keys(coll), values(coll)) + xf = xf0 + else + arrays = (coll,) + xf = opcompose(Map(first), xf0) + end + rf = _reducingfunction(xf, op; init = init) + acc = transduce_shfl_impl(rf, init, arrays...) + rf_dev = mtlconvert(rf) + if rf_dev === rf + result = complete(rf, acc) + else + result = complete_on_device(rf_dev, acc) + end + if unreduced(result) isa DefaultInitOf + throw(EmptyResultError(rf)) + end + return result +end + +macro _inbounds(ex) + ex = :($Base.@inbounds $ex) + esc(ex) +end + +function transduce_shfl_impl(rf::F, init, arrays...) where {F} + ys, = (dest, buf) = transduce_shfl!(nothing, rf, init, arrays...) + if buf === nothing + # The accumulator is a singleton. Once we are finished with the + # side-effects of the basecase, transduce is done: + return ys + end + # @info "ys, = transduce_shfl!(nothing, rf, ...)" Text(summary(ys)) + # @info "ys, = transduce_shfl!(nothing, rf, ...)" collect(ys) + length(ys) == 1 && return @allowscalar ys[1] + rf2 = AlwaysCombine(rf) + while true + ys, = transduce_shfl!(buf, rf2, init, ys) + # @info "ys, = transduce_shfl!(buf, rf2, ...)" Text(summary(ys)) + # @info "ys, = transduce_shfl!(buf, rf2, ...)" collect(ys) + length(ys) == 1 && return @allowscalar ys[1] + dest, buf = buf, dest + # reusing buffer; is it useful? + end +end + +@inline function transduce_shfl!(buf, rf, init, arrays...) + idx = eachindex(arrays...) + n = Int(length(idx)) # e.g., `length(UInt64(0):UInt64(1))` is not an `Int` + + #dev = device() + #wsize = warpsize(dev) + wsize = 32 + WARP_SIZE = Val(wsize) + + acctype = if buf === nothing + _infer_acctype(rf, init, arrays, true) + else + eltype(buf) + end + buf0 = if Base.issingletontype(acctype) + nothing + elseif buf === nothing + # TODO: find a way to compute type for `cufunction` without + # creating a dummy object. + allocate_buffer(acctype, 0) + else + buf + end + args = (buf0, Val{acctype}(), WARP_SIZE, rf, init, 0, idx, arrays...) 
+ # global _KARGS = args + kernel_tt = Tuple{map(x -> Typeof(mtlconvert(x)), args)...} + # global KERNEL_TT = kernel_tt + kernel = mtlfunction(transduce_shfl_kernel!, kernel_tt) + effelsize = if isbitstype(acctype) + sizeof(acctype) + else + error("UnionArrays not supported") + sizeof(UnionArrays.buffereltypefor(acctype)) + sizeof(UInt8) + end + # @show acctype UnionArrays.buffereltypefor(acctype) effelsize + #kernel_config = launch_configuration(kernel.fun) + # @show kernel_config + #=threads = let wanted_threads = nextwarp(dev, n) + given_threads = kernel_config.threads + # @show wanted_threads + if wanted_threads > given_threads + prevwarp(dev, given_threads) + else + wanted_threads + end + end=# + threads = 32 + + @assert threads <= wsize * wsize # = 32 * 32 = 1024 + + nwarps_per_block, _nwarps_rem = divrem(threads, wsize) + # @show threads nwarps_per_block _nwarps_rem + @assert _nwarps_rem % wsize == 0 + #basesize = nextwarp(dev, cld(n, kernel_config.blocks * nwarps_per_block)) + basesize = n + groups = cld(n, basesize * nwarps_per_block) + #@assert blocks <= kernel_config.blocks + + # @show threads, blocks, basesize, acctype + if Base.issingletontype(acctype) + @metal( + threads = threads, + groups = groups, + transduce_shfl_kernel!( + nothing, + Val{acctype}(), + WARP_SIZE, + rf, + init, + basesize, + idx, + arrays..., + ) + ) + return acctype.instance, nothing + end + + if buf === nothing + dest_buf = allocate_buffer(acctype, groups + cld(groups, threads)) + dest = view(dest_buf, 1:groups) + buf = view(dest_buf, groups+1:length(dest_buf)) + else + dest = view(buf, 1:groups) + end + + @metal( + threads = threads, + groups = groups, + transduce_shfl_kernel!( + dest, + Val{acctype}(), + WARP_SIZE, + rf, + init, + basesize, + idx, + arrays..., + ) + ) + + return dest, buf +end + +@inline function transduce_shfl_kernel!( + dest::Union{AbstractArray,Nothing}, + ::Val{T}, + ::Val{WARP_SIZE}, + rf::F, + init, + basesize, # length of the main loop per warp + idx, + arrays..., +) where {T,WARP_SIZE,F} + + @inline function _shfl_down(x, delta) + if dest === nothing + return x + end + uv = unionvalue(T, x) + return interpret(simd_shuffle_down(uv, delta)) + end + + nwarps_per_block, _nwarps_rem = divrem(threadgroups_per_grid_1d(), WARP_SIZE) + @assert _nwarps_rem == 0 + warpIdx0, warp_offset = divrem(thread_position_in_threadgroup_1d() - 1, WARP_SIZE) + warpIdx = warpIdx0 + 1 + warp_leader = 1 + warpIdx0 * WARP_SIZE # first thread of this warp + @assert warp_leader + warp_offset == thread_position_in_threadgroup_1d() + + main_offset = warpIdx0 + (threadgroup_position_in_grid_1d() - 1) * nwarps_per_block + main_bound0 = basesize * (main_offset + 1) # bound for this warp in `eachindex(idx)` + need_remainder = main_bound0 > lastindex(idx) # `idx` too short for the last block + main_bound = min(main_bound0, lastindex(idx)) - WARP_SIZE + need_remainder && @assert threadgroup_position_in_grid_1d() == threads_per_grid_1d + + # if warp_offset == 0 + # @cuprintf("%03ld: warpIdx0 = %ld main_offset = %d\n", threadIdx().x, Int(warpIdx0), Int(main_offset)) + # end + + # Main O(N) loop: + acc = start(rf, init) + warp_leader_offset = basesize * main_offset # offset for this warp + while warp_leader_offset <= main_bound + i = warp_leader_offset + warp_offset + 1 + # @cuprintf("%03ld: i = %d\n", threadIdx().x, Int(i)) + acc = @manual_union_split( + acc isa ithtype(T, Val(1)), + acc isa ithtype(T, Val(2)), + acc isa ithtype(T, Val(3)), + acc isa ithtype(T, Val(4)), + acc isa ithtype(T, Val(5)), + acc isa 
ithtype(T, Val(6)), + ) do + next(rf, acc, @_inbounds getvalues(idx[i], arrays...)) + end + + # Warp-wide merge: + delta = 1 + while delta < WARP_SIZE + acc = _combine(rf, acc, _shfl_down(acc, delta)) + delta <<= 1 + end + if warp_offset != 0 + acc = start(rf, init) + end + warp_leader_offset += WARP_SIZE + end + + # Remainder of the main loop: + if need_remainder + let i = warp_leader_offset + warp_offset + 1 + # @cuprintf("%03ld: (rem) i = %d\n", threadIdx().x, Int(i)) + if i <= lastindex(idx) + acc = @manual_union_split( + acc isa ithtype(T, Val(1)), + acc isa ithtype(T, Val(2)), + acc isa ithtype(T, Val(3)), + acc isa ithtype(T, Val(4)), + acc isa ithtype(T, Val(5)), + acc isa ithtype(T, Val(6)), + ) do + next(rf, acc, @_inbounds getvalues(idx[i], arrays...)) + end + end + + delta = 1 + while delta < WARP_SIZE + acc = _combine(rf, acc, _shfl_down(acc, delta)) + delta <<= 1 + end + end + end + # @cuprintf("%03ld: acc = %f\n", threadIdx().x, acc) + + dest === nothing && return + + # Preparing for block-wide merge: + @assert nwarps_per_block <= 32 + if isbitstype(T) + shared = MtlThreadGroupArray(T, 32) + else + error("UnionArrays not supported") + S = UnionArrays.buffereltypefor(T) + data = MtlThreadGroupArray(S, 32) + typeids = MtlThreadGroupArray(UInt8, 32) + @assert UInt(pointer(data, length(data) + 1)) == UInt(pointer(typeids)) + shared = UnionVector(T, data, typeids) + end + if warp_offset == 0 + @_inbounds shared[warpIdx] = acc + end + + shared_bound = + let n = length(idx), + nbasecases = cld(n, basesize), + offsetb = (threadgroup_position_in_grid_1d() - 1) * nwarps_per_block, + input_bound = nbasecases - offsetb + + min(input_bound, nwarps_per_block) + end + + # Block-wide merge: + shared_delta = 1 + while shared_delta < nwarps_per_block + + # Gather `WARP_SIZE` elements: + i = warp_leader + shared_delta * warp_offset + acc = start(rf, init) + + threadgroup_barrier() + if i <= shared_bound + acc = @_inbounds shared[i] + end + + # Warp-wide merge: + delta = 1 + while delta < WARP_SIZE + acc = _combine(rf, acc, _shfl_down(acc, delta)) + delta <<= 1 + end + + if warp_offset == 0 && thread_position_in_threadgroup_1d() <= lastindex(shared) + @_inbounds shared[thread_position_in_threadgroup_1d()] = acc + end + + shared_delta *= WARP_SIZE + end + + if thread_position_in_threadgroup_1d() == 1 + @_inbounds dest[threadgroup_position_in_grid_1d()] = acc + end + + return +end diff --git a/ext/FoldsMetalExt/unionvalues.jl b/ext/FoldsMetalExt/unionvalues.jl new file mode 100644 index 0000000..81a585e --- /dev/null +++ b/ext/FoldsMetalExt/unionvalues.jl @@ -0,0 +1,97 @@ +const NTypes{N} = NTuple{N, Val} + +valueof(::Val{x}) where {x} = x + +# Not exactly `Base.aligned_sizeof` +Base.@pure function sizeof_aligned(T::Type) + if isbitstype(T) + al = Base.datatype_alignment(T) + return (Core.sizeof(T) + al - 1) & -al + else + return nothing + end +end + +@inline foldlargs(op, x) = x +@inline foldlargs(op, x1, x2, xs...) = foldlargs(op, @return_if_reduced(op(x1, x2)), xs...) + +terminating_foldlargs(op, fallback) = fallback() +@inline function terminating_foldlargs(op, fallback::F, x1, x2, xs...) where {F} + acc = op(x1, x2) + acc isa Reduced && return unreduced(acc) + xs isa Tuple{} && return fallback() + return terminating_foldlargs(op, fallback, acc, xs...) 
+end
+
+@inline foldrunion(op, ::Type{T}, init) where {T} =
+    if T isa Union
+        acc = @return_if_reduced foldrunion(op, T.b, init)
+        foldrunion(op, T.a, acc)
+    else
+        op(T, init)
+    end
+
+@generated asntypes(::Type{T}) where {T} =
+    QuoteNode(foldrunion((S, types) -> (Val(S), types...), T, ()))
+
+struct UnionValue{T <: NTypes,NBytes}
+    types::T
+    data::NTuple{NBytes,UInt32}
+    typeid::UInt8
+end
+
+@noinline unreachable() = error("unreachable")
+
+@inline function unionvalue(::Type{T}, v::T) where {T}
+    if T isa Union
+        types = asntypes(T)
+        nbytes = foldrunion(T, 0) do S, n
+            Base.@_inline_meta
+            max(sizeof_aligned(S), n)
+        end
+        dest = Ref(ntuple(_ -> UInt32(0), Val(nbytes)))
+        GC.@preserve dest begin
+            unsafe_store!(Ptr{typeof(v)}(pointer_from_objref(dest)), v)
+        end
+        @inline function searchid((v, id), t)
+            if v isa valueof(t)
+                Reduced(id)
+            else
+                (v, id + 1)
+            end
+        end
+        typeid = terminating_foldlargs(searchid, unreachable, (v, 1), types...)
+        return UnionValue(types, dest[], UInt8(typeid))
+    else
+        return v
+    end
+end
+
+@noinline invalid_typeid() = error("invalid typeid")
+
+interpret(x) = x
+@inline function interpret(uv::UnionValue)
+    data = uv.data
+    typeid = uv.typeid
+    @inline function _get(id, t)
+        if id == typeid
+            T = valueof(t)
+            ref = Ref(data)
+            GC.@preserve ref begin
+                v = unsafe_load(Ptr{T}(pointer_from_objref(ref)))
+            end
+            return Reduced(v)
+        else
+            id + 1
+        end
+    end
+    return terminating_foldlargs(_get, invalid_typeid, 1, uv.types...)
+end
+
+#=
+@inline function Metal.shfl_recurse(op, uv::UnionValue)
+    data = map(op, uv.data)
+    typeid = op(uv.typeid)
+    return UnionValue(uv.types, data, typeid)
+end
+=#
\ No newline at end of file
diff --git a/ext/FoldsMetalExt/utils.jl b/ext/FoldsMetalExt/utils.jl
new file mode 100644
index 0000000..a0d9e55
--- /dev/null
+++ b/ext/FoldsMetalExt/utils.jl
@@ -0,0 +1,38 @@
+macro manual_union_split(body::Expr, conditions...)
+    if !(body.head === :-> && length(body.args) == 2 && body.args[1] == :(()))
+        error(
+            "`@manual_union_split` is intended to be used with a `do` block",
+            " with no arguments",
+        )
+    end
+    body = body.args[2]
+    ex = foldr(conditions; init = body) do c, ex
+        quote
+            if $c
+                $body
+            else
+                $ex
+            end
+        end
+    end
+    esc(ex)
+end
+
+function ithtype(::Type{T}, i::Val) where {T}
+    S = foldrunion(T, Val(1)) do S, j
+        if j === i
+            S
+        else
+            if j isa Val
+                Val(valueof(j) + 1)
+            else
+                j
+            end
+        end
+    end
+    if S isa Type
+        return S
+    else
+        return Union{}
+    end
+end
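+
+# Rough illustration (a sketch, not part of the extension's API) of how the two
+# helpers above compose: for a small accumulator union such as
+# `T = Union{Int32, Float32}`, `ithtype(T, Val(i))` yields the members of the
+# union one at a time (and `Union{}` once `i` runs past them), so a call like
+#
+#     acc = @manual_union_split(
+#         acc isa ithtype(T, Val(1)),
+#         acc isa ithtype(T, Val(2)),
+#     ) do
+#         next(rf, acc, x)
+#     end
+#
+# expands to nested `if`/`else` branches that each run the same body, letting
+# the GPU compiler specialize the body for the concrete accumulator type
+# selected by the `isa` test.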