Sketching new API: FLoops.@combine #114

@tkf

It is sometimes useful to define the reduction inside a basecase separately from the reduction across basecases. A typical example is histogram computation. Currently, implementing this requires coming up with a gadget like OneHotVector:

@floop ex for i in indices
    @reduce h .+= OneHotVector(i => 1, n)
end

h :: Vector{Int}  # computed histogram

But, if we don't have OneHotVector, it's tricky for users to define this. It may be a good idea to support more verbose but controllable syntax.
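To illustrate what such a gadget involves, here is a minimal sketch of a hand-rolled one-hot vector using only Base. The type name `OneHot` and its fields are hypothetical and are not the implementation behind `OneHotVector`; the point is only that the user has to define a custom array type just to express "add 1 to one bin" as a broadcast.

```julia
# Hypothetical stand-in for a one-hot gadget (illustration only).
# It behaves like a length-`len` vector that is 1 at `index` and 0 elsewhere.
struct OneHot <: AbstractVector{Int}
    index::Int
    len::Int
end

Base.size(v::OneHot) = (v.len,)
Base.getindex(v::OneHot, i::Int) = Int(i == v.index)

# Accumulating one-hot vectors into a dense histogram via broadcasting:
h = zeros(Int, 5)
h .+= OneHot(2, 5)
h .+= OneHot(2, 5)
h .+= OneHot(4, 5)
```

Each observation is materialized as an array-like object before being merged, which is the indirection the proposal below tries to avoid.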

New syntax

The idea is to add a new syntax, for example:

@floop begin
    @init buf = zeros(Int, 10)  # per basecase initialization
    for x in xs
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1  # reduction within basecase (no syntax)
    end
    @combine h .+= buf  # reduction across basecases
end

h :: Vector{Int}  # computed histogram

The new macro @combine takes the same expressions as @reduce does. Unlike @reduce, however, it is not executed inside the loop body.

This is lowered to something equivalent to

function op!!((_, h), (is_basecase, x))
    if is_basecase
        # The left argument is the `buf` inside of basecase:
        buf = h

        # Fused loop body:
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1
    else
        # The right argument is the `buf` when combining sub-solutions:
        buf = x

        # `@combine` instructions:
        h .+= buf
    end
    return (false, h)
end

init() = (false, zeros(Int, 10))

Folds.reduce(op!!, ((true, x) for x in xs); init = OnInit(init))

(with extra care so that the compiler can eliminate the branch in op!! and the basecase is compiled down to a straight loop)
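The intended semantics of `@init`, the loop body, and `@combine` can also be pictured without any library support. The following is a rough sketch using only Base tasks, not the proposed lowering; `chunked_histogram`, `nbins`, and `nchunks` are hypothetical names introduced for illustration.

```julia
# Plain-Julia sketch of the two-phase reduction (no FLoops required).
function chunked_histogram(xs; nbins = 10, nchunks = 4)
    chunksize = max(1, cld(length(xs), nchunks))
    # One task per basecase; each task owns a private buffer.
    tasks = map(Iterators.partition(xs, chunksize)) do chunk
        Threads.@spawn begin
            buf = zeros(Int, nbins)      # plays the role of `@init`
            for x in chunk
                bin = max(1, min(nbins, floor(Int, x)))
                buf[bin] += 1            # reduction within the basecase
            end
            buf
        end
    end
    h = zeros(Int, nbins)
    for t in tasks
        h .+= fetch(t)                   # plays the role of `@combine`
    end
    return h
end

h = chunked_histogram([0.5, 1.2, 1.9, 2.5, 10.0, 99.0]; nchunks = 3)
```

The sketch combines the buffers sequentially at the end, whereas a real executor would be free to combine sub-solutions pairwise in a tree.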

The name @combine reflects the Transducers API Transducers.combine.

Comparison

Example: collatz_histogram

https://juliafolds.github.io/data-parallelism/tutorials/quick-introduction/#practical_example_histogram_of_stopping_time_of_collatz_function is an example of using FLoops.jl to compute a histogram when you don't know the upper bound:

using FLoops
using MicroCollections: SingletonDict

maxkey(xs::AbstractVector) = lastindex(xs)
maxkey(xs::SingletonDict) = first(keys(xs))

function collatz_histogram(xs, executor = ThreadedEx())
    @floop executor for x in xs
        n = collatz_stopping_time(x)
        n > 0 || continue
        obs = SingletonDict(n => 1)
        @reduce() do (hist = Int[]; obs)
            l = length(hist)
            m = maxkey(obs)  # obs is a Vector or SingletonDict
            if l < m
                # Stretch `hist` so that the merged result fits in it.
                resize!(hist, m)
                fill!(view(hist, l+1:m), 0)
            end
            # Merge `obs` into `hist`:
            for (k, v) in pairs(obs)
                @inbounds hist[k] += v
            end
        end
    end
    return hist
end

This can be written as

using FLoops

function collatz_histogram(xs, executor = ThreadedEx())
    @floop executor begin
        @init buf = Int[]

        for x in xs
            n = collatz_stopping_time(x)
            n > 0 || continue

            l = length(buf)
            if l < n
                resize!(buf, n)
                fill!(view(buf, l+1:n), 0)
            end
            @inbounds buf[n] += 1
        end

        @combine() do (hist; buf)
            l = length(hist)
            n = length(buf)
            if n > l
                resize!(hist, n)
                fill!(view(hist, l+1:n), 0)
            end
            @views hist[1:n] .+= buf
        end
    end
    return hist
end

Compared to the @reduce version, the @combine version has more repetition (for resize!). However, it can be written without coming up with an abstraction like maxkey and without knowing about SingletonDict.
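One possible way to reduce the repetition is to factor the "stretch and zero-fill" step into a small helper that both the loop body and the `@combine` body could call. This is only a sketch in plain Julia; `grow!`, `count!`, and `merge_hist!` are hypothetical names and not part of FLoops.

```julia
# Hypothetical helper: grow `v` to at least length `n`, zero-filling new slots.
function grow!(v::Vector{Int}, n::Integer)
    l = length(v)
    if l < n
        resize!(v, n)
        fill!(view(v, l+1:n), 0)
    end
    return v
end

# What the loop body does: count one observation with stopping time `n`.
function count!(buf::Vector{Int}, n::Integer)
    grow!(buf, n)
    buf[n] += 1
    return buf
end

# What the `@combine` body does: merge `buf` into `hist`.
function merge_hist!(hist::Vector{Int}, buf::Vector{Int})
    n = length(buf)
    grow!(hist, n)
    @views hist[1:n] .+= buf
    return hist
end

buf = Int[]
count!(buf, 3)
count!(buf, 1)
hist = merge_hist!([2], buf)
```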

Example: using mul!

https://juliafolds.github.io/data-parallelism/tutorials/mutations/#advanced_fusing_multiplication_and_addition_in_base_cases shows how to use 5-arg mul!:

using FLoops
using LinearAlgebra: mul!

@floop for (A, B) in zip(As, Bs)
    C = (A, B)
    @reduce() do (S = zero(A); C)
        if C isa Tuple  # base case
            mul!(S, C[1], C[2], 1, 1)
        else            # combining base cases
            S .+= C
        end
    end
end

This can be written as

using FLoops
using LinearAlgebra: mul!

@floop begin
    @init C = zero(As[1])

    for (A, B) in zip(As, Bs)
        mul!(C, A, B, 1, 1)
    end

    @combine S .+= C
end

Using @combine here is much cleaner than using @reduce.
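For reference, the basecase in both versions relies on 5-arg `mul!(C, A, B, α, β)`, which updates `C` in place to `A*B*α + C*β`. A minimal sequential sketch of a single basecase follows; `sum_of_products` is a hypothetical name, and the sketch assumes, like the examples above, square matrices of equal size so that `zero(As[1])` has the right shape.

```julia
using LinearAlgebra: mul!

# Sequential version of one basecase: accumulate A*B into C without
# allocating a temporary for each product.
function sum_of_products(As, Bs)
    C = zero(As[1])               # assumes all matrices are square and same-sized
    for (A, B) in zip(As, Bs)
        mul!(C, A, B, 1, 1)       # C = A*B*1 + C*1
    end
    return C
end

As = [[1.0 2.0; 3.0 4.0], [0.0 1.0; 1.0 0.0]]
Bs = [[1.0 0.0; 0.0 1.0], [2.0 0.0; 0.0 2.0]]
S = sum_of_products(As, Bs)
```

In the parallel version, each basecase produces one such `C`, and `@combine S .+= C` adds them together.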

Discussion/feedback
