Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ Transducers = "0.4.35"
julia = "1.6"

[extensions]
FoldsMetalExt = "Metal"
FoldsOnlineStatsBaseExt = "OnlineStatsBase"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Documenter", "Test"]

[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
56 changes: 56 additions & 0 deletions ext/FoldsMetalExt/FoldsMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module FoldsMetalExt

export MetalEx, CoalescedMetalEx

using Metal
using Metal: @allowscalar, mtlconvert, mtlfunction
using Core: Typeof
using InitialValues: InitialValue, asmonoid
#using UnionArrays: UnionArrays, UnionVector
using Transducers:
@return_if_reduced,
Executor,
Map,
Reduced,
Transducer,
Transducers,
combine,
complete,
next,
opcompose,
reduced,
start,
transduce,
unreduced

# TODO: Don't import internals from Transducers:
using Transducers:
AbstractReduction,
DefaultInit,
DefaultInitOf,
EmptyResultError,
IdentityTransducer,
Reduction,
_reducingfunction,
completebasecase,
extract_transducer,
foldl_basecase

include("utils.jl")
include("kernels.jl")
include("unionvalues.jl")
include("shfl.jl")
include("api.jl")
include("introspection.jl")

#= Use README as the docstring of the module:
@doc let path = joinpath(dirname(@__DIR__), "README.md")
include_dependency(path)
doc = read(path, String)
doc = replace(doc, r"^```julia"m => "```jldoctest README")
doc = replace(doc, "(https://juliafolds.github.io/FoldsMetalExt.jl/dev/examples/)" => "(@ref examples-toc)")
doc
end FoldsMetalExt
=#

end
42 changes: 42 additions & 0 deletions ext/FoldsMetalExt/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
MetalEx()

A fold executor implemented using Metal.jl.

For more information about executor, see
[Transducers.jl's glossary section](https://juliafolds.github.io/Transducers.jl/dev/explanation/glossary/#glossary-executor)
and
[FLoops.jl's API section](https://juliafolds.github.io/FLoops.jl/dev/reference/api/#executor).

# Examples
```jldoctest
julia> using FoldsMetalExt, Folds

julia> Folds.sum(1:10, MetalEx())
55
```
"""
struct MetalEx{K} <: Executor
kwargs::K
end

popsimd(; simd = nothing, kwargs...) = kwargs

Transducers.transduce(xf, rf::RF, init, xs, exc::MetalEx) where {RF} =
_transduce_metal(xf, rf, init, xs; popsimd(; exc.kwargs...)...)

Transducers.executor_type(::MtlArray) = MetalEx

"""
CoalescedMetalEx()

A fold executor implemented using Metal.jl. It uses coalesced memory access
while supporting non-commutative loops. It can be faster but more limited than
`MetalEx`.
"""
struct CoalescedMetalEx{K} <: Executor
kwargs::K
end

Transducers.transduce(xf, rf::RF, init, xs, exc::CoalescedMetalEx) where {RF} =
transduce_shfl(xf, rf, init, xs; popsimd(; exc.kwargs...)...)
33 changes: 33 additions & 0 deletions ext/FoldsMetalExt/introspection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
const RUN_ON_HOST_IF_NORETURN = Ref(false)

struct FailedInference <: Exception
f::Any
kernel_args::Tuple
kernel_return::Any
host_args::Tuple
host_return::Any
end

function Base.showerror(io::IO, err::FailedInference)
print(io, FailedInference, ": ")
if err.kernel_return === Union{}
print(io, "Kernel is inferred to not return (return type is `Union{}`)")
else
print(io, "Kernel is inferred to return invalid type: ", err.kernel_return)
end
if err.kernel_return === Union{} && err.host_return !== Union{}
println(io)
print(io, "Note: on the host, the return type is inferred as ", err.host_return)
end
println(io)
printstyled(io, "HINT"; bold = true, color = :light_black)
printstyled(
io,
": if this exception is caught as `err``, use `Metal.code_typed(err)` to",
" introspect the erronous code.";
color = :light_black,
)
end

Metal.code_typed(err::FailedInference; options...) =
Metal.code_typed(err.f, typeof(err.kernel_args); options...)
Loading
Loading