Skip to content

Commit a5d109c

Browse files
committed
Add @completebasecase macro
1 parent bcaf1ad commit a5d109c

File tree

7 files changed

+251
-21
lines changed

7 files changed

+251
-21
lines changed

src/FLoops.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,16 @@ module FLoops
2020
doc
2121
end FLoops
2222

23-
export @floop, @init, @combine, @reduce, DistributedEx, SequentialEx, ThreadedEx
23+
#! format: off
24+
export @floop,
25+
@init,
26+
@combine,
27+
@reduce,
28+
@completebasecase,
29+
DistributedEx,
30+
SequentialEx,
31+
ThreadedEx
32+
#! format: on
2433

2534
using BangBang.Extras: broadcast_inplace!!
2635
using BangBang: materialize!!, push!!

src/combine.jl

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,104 @@ struct CombineOpSpec <: OpSpec
3333
end
3434

3535
# Convenience constructor: build a `CombineOpSpec` from its argument list alone,
# passing an empty `Symbol[]` as the second field.
function CombineOpSpec(args::Vector{Any})
    return CombineOpSpec(args, Symbol[])
end
36+
# Name of the macro that produces a `CombineOpSpec`.
function macroname(::CombineOpSpec)
    return Symbol("@combine")
end
37+
38+
# Without a macro like `@completebasecase`, it'd be confusing to have an
39+
# expression such as
40+
#
41+
# @floop begin
42+
# ...
43+
# for x in xs
44+
# ... # executed in parallel loop body
45+
# end
46+
# for y in ys # executed in completebasecase hook
47+
# ...
48+
# end
49+
# ...
50+
# end
51+
#
52+
# i.e., two similar loops have drastically different semantics. The difference
53+
# can be clarified by using the syntax:
54+
#
55+
# @floop begin
56+
# ...
57+
# for x in xs
58+
# ... # executed in parallel loop body
59+
# end
60+
# @completebasecase begin
61+
# for y in ys # executed in completebasecase hook
62+
# ...
63+
# end
64+
# end
65+
# ...
66+
# end
67+
"""
68+
@completebasecase ex
69+
70+
Evaluate expression `ex` at the end of each basecase. The expression `ex` can
71+
only refer to the variables declared by `@init`.
72+
73+
`@completebasecase` can be omitted if `ex` does not contain a `for` loop.
74+
75+
# Examples
76+
```jldoctest
77+
julia> using FLoops
78+
79+
julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end];
80+
81+
julia> @floop begin
82+
@init hist = zeros(Int, 10)
83+
for c in pidigits
84+
i = c - '0' + 1
85+
hist[i] += 1
86+
end
87+
@completebasecase begin
88+
j = 0
89+
y = 0
90+
for (i, x) in pairs(hist) # pretending we don't have `argmax`
91+
if x > y
92+
j = i
93+
y = x
94+
end
95+
end
96+
peaks = [j]
97+
nchunks = [sum(hist)]
98+
end
99+
@combine hist .+= _
100+
@combine peaks = append!(_, _)
101+
@combine nchunks = append!(_, _)
102+
end
103+
```
104+
"""
105+
macro completebasecase(ex)
    # Keep the call-site location next to the user expression by wrapping both
    # in a block.
    wrapped = Expr(:block, __source__, ex)
    # The expansion is a `throw(spec)` call carrying the spec object itself;
    # `extract_spec` recognises this call shape.
    return :(throw($(CompleteBasecaseOp(wrapped))))
end
109+
110+
# Marker object produced by `@completebasecase`; it carries the wrapped user
# expression.
struct CompleteBasecaseOp
    # Block expression built by the macro (call-site line node + user code).
    ex::Expr
end
113+
114+
# Return the spec object embedded in `ex`, or `nothing`.
#
# `ex` carries a spec when it is a `:call` expression with exactly one argument
# (callee + argument, i.e. two entries in `ex.args`) and that argument is a
# `ReduceOpSpec`, `CombineOpSpec`, `InitSpec` or `CompleteBasecaseOp`. The
# callee itself is not inspected.
function extract_spec(ex)
    (ex isa Expr && ex.head === :call && length(ex.args) == 2) || return nothing
    payload = ex.args[2]
    if payload isa Union{ReduceOpSpec,CombineOpSpec,InitSpec,CompleteBasecaseOp}
        return payload
    end
    return nothing
end
123+
124+
# Build a predicate that is `true` for expressions whose extracted spec
# (see `extract_spec`) is an instance of `T`.
function isa_spec(::Type{T}) where {T}
    return candidate -> extract_spec(candidate) isa T
end
36125

37126
function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = nothing)
38-
iterspec, body, ansvar, pre, post = destructure_loop_pre_post(ex)
127+
iterspec, body, ansvar, pre, post = destructure_loop_pre_post(
128+
ex;
129+
multiple_loop_note = string(
130+
" Wrap the expressions after the first loop (parallel loop) with",
131+
" `@completebasecase`.",
132+
),
133+
)
39134
@assert ansvar == :_
40135

41136
parallel_loop_ex = @match iterspec begin
@@ -50,15 +145,6 @@ function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = not
50145
return parallel_loop_ex
51146
end
52147

53-
function extract_spec(ex)
54-
@match ex begin
55-
Expr(:call, throw′, spec::ReduceOpSpec) => spec
56-
Expr(:call, throw′, spec::CombineOpSpec) => spec
57-
Expr(:call, throw′, spec::InitSpec) => spec
58-
_ => nothing
59-
end
60-
end
61-
62148
function as_parallel_combine_loop(
63149
ctx::MacroContext,
64150
pre::Vector,
@@ -70,6 +156,7 @@ function as_parallel_combine_loop(
70156
executor,
71157
)
72158
@assert simd in (false, true, :ivdep)
159+
foreach(disalow_raw_for_loop_without_completebasecase, post)
73160

74161
init_exprs = []
75162
all_rf_accs = []
@@ -89,11 +176,27 @@ function as_parallel_combine_loop(
89176
# `next` reducing step function:
90177
base_accs = mapcat(identity, all_rf_accs)
91178

92-
firstcombine = something(
93-
findfirst(x -> extract_spec(x) isa CombineOpSpec, post),
94-
lastindex(post) + 1,
95-
)
179+
firstcombine = something(findfirst(isa_spec(CombineOpSpec), post), lastindex(post) + 1)
180+
96181
completebasecase_exprs = post[firstindex(post):firstcombine-1]
182+
if any(isa_spec(CompleteBasecaseOp), completebasecase_exprs)
183+
# If `CompleteBasecaseOp` is used, this must be the only expression:
184+
let exprs = [x for x in completebasecase_exprs if !(x isa LineNumberNode)],
185+
spec = extract_spec(exprs[1])
186+
187+
if spec isa CompleteBasecaseOp && length(exprs) == 1
188+
completebasecase_exprs = Any[spec.ex]
189+
elseif all(isa_spec(CompleteBasecaseOp), exprs)
190+
error("Only one `@completebasecase` can be used. got:\n", join(exprs, "\n"))
191+
else
192+
error(
193+
"`@completebasecase` cannot be mixed with other expressions.",
194+
" Put everything in `@completebasecase begin ... end`. got:\n",
195+
join(exprs, "\n"),
196+
)
197+
end
198+
end
199+
end
97200

98201
left_accs = []
99202
right_accs = []
@@ -104,7 +207,8 @@ function as_parallel_combine_loop(
104207
spec = extract_spec(ex)
105208
if !(spec isa CombineOpSpec)
106209
error(
107-
"non-`@combine` expressions must be placed between `for` loop and the first `@combine` expression: ",
210+
"non-`@combine` expressions must be placed between `for` loop and the",
211+
" first `@combine` expression: ",
108212
spec,
109213
)
110214
end
@@ -279,3 +383,22 @@ function process_combine_op_spec(
279383
# TODO: use accurate line number from `@combine`
280384
return (; left = left, right = right, combine_body = combine_body)
281385
end
386+
387+
# Reject `for` loops anywhere inside `ex`, unless `ex` is itself a spec
# expression (as recognised by `extract_spec`). Non-`Expr` values are ignored.
# Always returns `nothing`; raises via `_disalow_raw_for_loop` on violation.
function disalow_raw_for_loop_without_completebasecase(@nospecialize(ex))
    if ex isa Expr && extract_spec(ex) === nothing
        _disalow_raw_for_loop(ex)
    end
    return
end
392+
393+
# Walk `ex` recursively and raise an error at the first `for` expression found.
# Non-`Expr` leaves are skipped. Returns `nothing` when no `for` loop exists.
function _disalow_raw_for_loop(@nospecialize(ex))
    ex isa Expr || return
    isexpr(ex, :for) && error(
        "`@floop begin ... end` can only contain one `for` loop.",
        " Use `@completebasecase begin ... end` to wrap the code after the parallel",
        " loop, including the `for` loop. Got:\n",
        ex,
    )
    # Descend into every sub-expression; a nested loop is rejected as well.
    for child in ex.args
        _disalow_raw_for_loop(child)
    end
    return
end

src/macro.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ end
6969
Goto{label}(acc::T) where {label,T} = Goto{label,T}(acc)
7070
gotoexpr(label::Symbol) = :($Goto{$(QuoteNode(label))})
7171

72-
function destructure_loop_pre_post(ex)
72+
function destructure_loop_pre_post(ex; multiple_loop_note = "")
7373
pre = post = Union{}[]
7474
ansvar = :_
7575
if isexpr(ex, :for)
@@ -83,7 +83,13 @@ function destructure_loop_pre_post(ex)
8383
pre = args[1:i-1]
8484
post = args[i+1:end]
8585
if find_first_for_loop(post) !== nothing
86-
throw(ArgumentError("Multiple top-level `for` loop found in:\n$ex"))
86+
msg = string(
87+
"Multiple top-level `for` loops found.",
88+
multiple_loop_note,
89+
" Given expression:\n",
90+
ex,
91+
)
92+
throw(ArgumentError(msg))
8793
end
8894
else
8995
throw(ArgumentError("Unsupported expression:\n$ex"))

src/reduce.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ struct ReduceOpSpec <: OpSpec
9494
end
9595

9696
# Convenience constructor: build a `ReduceOpSpec` from its argument list alone,
# passing an empty `Symbol[]` as the second field.
function ReduceOpSpec(args::Vector{Any})
    return ReduceOpSpec(args, Symbol[])
end
97+
# Name of the macro that produces a `ReduceOpSpec`.
function macroname(::ReduceOpSpec)
    return Symbol("@reduce")
end
9798

9899
"""
99100
@init begin
@@ -854,9 +855,6 @@ struct _FLoopInit end
854855
transduce(IdentityTransducer(), rf, DefaultInit, coll, maybe_set_simd(exc, simd)),
855856
)
856857

857-
macroname(::ReduceOpSpec) = Symbol("@reduce")
858-
macroname(::CombineOpSpec) = Symbol("@combine")
859-
860858
function Base.print(io::IO, spec::OpSpec)
861859
# TODO: print as `do` block
862860
print(io, macroname(spec), "(")

test/FLoopsTests/src/FLoopsTests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module FLoopsTests
22

33
using Test
44

5+
include("utils.jl")
6+
57
for file in
68
sort([file for file in readdir(@__DIR__) if match(r"^test_.*\.jl$", file) !== nothing])
79
include(file)

test/FLoopsTests/src/test_combine.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using MicroCollections
55
using StaticArrays
66
using Test
77

8+
using ..Utils: @macroexpand_error
9+
810
function count_ints_two_pass(indices, ex = nothing)
911
l, h = extrema(indices)
1012
n = h - l + 1
@@ -18,10 +20,28 @@ function count_ints_two_pass(indices, ex = nothing)
1820
return hist
1921
end
2022

23+
# Recover the value stored in the type parameter of a `Val`.
function valueof(::Val{x}) where {x}
    return x
end
24+
25+
# Histogram of the integers in `indices`, shifted so the smallest value lands in
# bin 1. Each basecase accumulates into an `MVector`, which `@completebasecase`
# converts to an `SVector` before the `@combine` step adds partial results.
function count_ints_two_pass2(indices, ex = nothing)
    lo, hi = extrema(indices)
    # Lift the bin count into a `Val` and read it back with `valueof` when
    # constructing the `MVector` type.
    nbins = Val(hi - lo + 1)
    @floop ex begin
        @init hist = zero(MVector{valueof(nbins),Int32})
        for i in indices
            hist[i-lo+1] += 1
        end
        @completebasecase hist = SVector(hist)
        @combine hist .+= _
    end
    return hist
end
end
38+
2139
# Run both histogram implementations under each executor and compare the
# results against hand-computed counts.
function test_count_ints_two_pass()
    executors = [SequentialEx(), nothing, ThreadedEx(basesize = 1)]
    @testset "$(repr(ex))" for ex in executors
        @test count_ints_two_pass(1:3, ex) == [1, 1, 1]
        @test count_ints_two_pass([1, 2, 4, 1], ex) == [2, 1, 0, 1]
        @test count_ints_two_pass2(1:3, ex) == [1, 1, 1]
        @test count_ints_two_pass2([1, 2, 4, 1], ex) == [2, 1, 0, 1]
    end
end
2747

@@ -94,4 +114,59 @@ function test_count_positive_ints()
94114
end
95115
end
96116

117+
# Two top-level `for` loops inside `@floop begin ... end` must fail at macro
# expansion with a hint to use `@completebasecase`.
function test_error_one_for_loop1()
    caught = @macroexpand_error @floop begin
        @init a = nothing
        for x in xs
        end
        for y in ys
        end
    end
    @test caught isa Exception
    @test occursin("Wrap the expressions after the first loop", sprint(showerror, caught))
end
129+
130+
# A `for` loop nested inside another expression after the parallel loop (here a
# function definition) must also be rejected at macro expansion.
function test_error_one_for_loop2()
    caught = @macroexpand_error @floop begin
        @init a = nothing
        for x in xs
        end
        function f()
            for y in ys
            end
        end
    end
    @test caught isa Exception
    @test occursin("can only contain one `for` loop", sprint(showerror, caught))
end
144+
145+
# `@completebasecase` followed by a plain expression must be rejected at macro
# expansion.
function test_error_mixing_plain_expr_and_completebasecase()
    caught = @macroexpand_error @floop begin
        @init a = nothing
        for x in xs
        end
        @completebasecase for y in ys
        end
        f(ys)
    end
    @test caught isa Exception
    @test occursin("cannot be mixed with other expressions", sprint(showerror, caught))
end
158+
159+
# Using `@completebasecase` twice in one `@floop` block must be rejected at
# macro expansion.
function test_error_two_completebasecase_macro_calls()
    caught = @macroexpand_error @floop begin
        @init a = nothing
        for x in xs
        end
        @completebasecase nothing
        @completebasecase nothing
    end
    @test caught isa Exception
    @test occursin("Only one `@completebasecase` can be used", sprint(showerror, caught))
end
171+
97172
end # module

test/FLoopsTests/src/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module Utils

# Sentinel returned by `@macroexpand_error` when expansion raised nothing.
struct NoError end

# Macro-expand `ex` at run time in the calling module and evaluate to the
# exception that expansion throws, or to `NoError()` when it succeeds.
macro macroexpand_error(ex)
    caught = gensym("err")
    guarded = quote
        try
            # `@eval` defers the `@macroexpand` call to run time so that an
            # expansion failure is raised inside this `try`.
            $Base.@eval $Base.@macroexpand $ex
            $NoError()
        catch $caught
            $caught
        end
    end
    return esc(guarded)
end

end # module

0 commit comments

Comments
 (0)