@@ -33,9 +33,104 @@ struct CombineOpSpec <: OpSpec
3333end
3434
3535CombineOpSpec (args:: Vector{Any} ) = CombineOpSpec (args, Symbol[])
36+ macroname (:: CombineOpSpec ) = Symbol (" @combine" )
37+
38+ # Without a macro like `@completebasecase`, it'd be confusing to have an
39+ # expression such as
40+ #
41+ # @floop begin
42+ # ...
43+ # for x in xs
44+ # ... # executed in parallel loop body
45+ # end
46+ # for y in ys # executed in completebasecase hook
47+ # ...
48+ # end
49+ # ...
50+ # end
51+ #
52+ # i.e., two similar loops have drastically different semantics. The difference
53+ # can be clarified by using the syntax:
54+ #
55+ # @floop begin
56+ # ...
57+ # for x in xs
58+ # ... # executed in parallel loop body
59+ # end
60+ # @completebasecase begin
61+ # for y in ys # executed in completebasecase hook
62+ # ...
63+ # end
64+ # end
65+ # ...
66+ # end
67+ """
68+ @completebasecase ex
69+
70+ Evaluate expression `ex` at the end of each basecase. The expression `ex` can
71+ only refer to the variables declared by `@init`.
72+
73+ `@completebasecase` can be omitted if `ex` does not contain a `for` loop.
74+
75+ # Examples
76+ ```jldoctest
77+ julia> using FLoops
78+
79+ julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end];
80+
81+ julia> @floop begin
82+ @init hist = zeros(Int, 10)
83+ for c in pidigits
84+ i = c - '0' + 1
85+ hist[i] += 1
86+ end
87+ @completebasecase begin
88+ j = 0
89+ y = 0
90+ for (i, x) in pairs(hist) # pretending we don't have `argmax`
91+ if x > y
92+ j = i
93+ y = x
94+ end
95+ end
96+ peaks = [j]
97+ nchunks = [sum(hist)]
98+ end
99+ @combine hist .+= _
100+ @combine peaks = append!(_, _)
101+ @combine nchunks = append!(_, _)
102+ end
103+ ```
104+ """
105+ macro completebasecase (ex)
106+ ex = Expr (:block , __source__, ex)
107+ :(throw ($ (CompleteBasecaseOp (ex))))
108+ end
109+
110+ struct CompleteBasecaseOp
111+ ex:: Expr
112+ end
113+
114+ function extract_spec (ex)
115+ @match ex begin
116+ Expr (:call , throw′, spec:: ReduceOpSpec ) => spec
117+ Expr (:call , throw′, spec:: CombineOpSpec ) => spec
118+ Expr (:call , throw′, spec:: InitSpec ) => spec
119+ Expr (:call , throw′, spec:: CompleteBasecaseOp ) => spec
120+ _ => nothing
121+ end
122+ end
123+
124+ isa_spec (:: Type{T} ) where {T} = x -> extract_spec (x) isa T
36125
37126function combine_parallel_loop (ctx:: MacroContext , ex:: Expr , simd, executor = nothing )
38- iterspec, body, ansvar, pre, post = destructure_loop_pre_post (ex)
127+ iterspec, body, ansvar, pre, post = destructure_loop_pre_post (
128+ ex;
129+ multiple_loop_note = string (
130+ " Wrap the expressions after the first loop (parallel loop) with" ,
131+ " `@completebasecase`." ,
132+ ),
133+ )
39134 @assert ansvar == :_
40135
41136 parallel_loop_ex = @match iterspec begin
@@ -50,15 +145,6 @@ function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = not
50145 return parallel_loop_ex
51146end
52147
53- function extract_spec (ex)
54- @match ex begin
55- Expr (:call , throw′, spec:: ReduceOpSpec ) => spec
56- Expr (:call , throw′, spec:: CombineOpSpec ) => spec
57- Expr (:call , throw′, spec:: InitSpec ) => spec
58- _ => nothing
59- end
60- end
61-
62148function as_parallel_combine_loop (
63149 ctx:: MacroContext ,
64150 pre:: Vector ,
@@ -70,6 +156,7 @@ function as_parallel_combine_loop(
70156 executor,
71157)
72158 @assert simd in (false , true , :ivdep )
159+ foreach (disalow_raw_for_loop_without_completebasecase, post)
73160
74161 init_exprs = []
75162 all_rf_accs = []
@@ -89,11 +176,27 @@ function as_parallel_combine_loop(
89176 # `next` reducing step function:
90177 base_accs = mapcat (identity, all_rf_accs)
91178
92- firstcombine = something (
93- findfirst (x -> extract_spec (x) isa CombineOpSpec, post),
94- lastindex (post) + 1 ,
95- )
179+ firstcombine = something (findfirst (isa_spec (CombineOpSpec), post), lastindex (post) + 1 )
180+
96181 completebasecase_exprs = post[firstindex (post): firstcombine- 1 ]
182+ if any (isa_spec (CompleteBasecaseOp), completebasecase_exprs)
183+ # If `CompleteBasecaseOp` is used, this must be the only expression:
184+ let exprs = [x for x in completebasecase_exprs if ! (x isa LineNumberNode)],
185+ spec = extract_spec (exprs[1 ])
186+
187+ if spec isa CompleteBasecaseOp && length (exprs) == 1
188+ completebasecase_exprs = Any[spec. ex]
189+ elseif all (isa_spec (CompleteBasecaseOp), exprs)
190+ error (" Only one `@completebasecase` can be used. got:\n " , join (exprs, " \n " ))
191+ else
192+ error (
193+ " `@completebasecase` cannot be mixed with other expressions." ,
194+ " Put everything in `@completebasecase begin ... end`. got:\n " ,
195+ join (exprs, " \n " ),
196+ )
197+ end
198+ end
199+ end
97200
98201 left_accs = []
99202 right_accs = []
@@ -104,7 +207,8 @@ function as_parallel_combine_loop(
104207 spec = extract_spec (ex)
105208 if ! (spec isa CombineOpSpec)
106209 error (
107- " non-`@combine` expressions must be placed between `for` loop and the first `@combine` expression: " ,
210+ " non-`@combine` expressions must be placed between `for` loop and the" ,
211+ " first `@combine` expression: " ,
108212 spec,
109213 )
110214 end
@@ -279,3 +383,22 @@ function process_combine_op_spec(
279383 # TODO : use accurate line number from `@combine`
280384 return (; left = left, right = right, combine_body = combine_body)
281385end
386+
387+ function disalow_raw_for_loop_without_completebasecase (@nospecialize (ex))
388+ ex isa Expr || return
389+ extract_spec (ex) === nothing || return
390+ _disalow_raw_for_loop (ex)
391+ end
392+
393+ function _disalow_raw_for_loop (@nospecialize (ex))
394+ ex isa Expr || return
395+ if isexpr (ex, :for )
396+ error (
397+ " `@floop begin ... end` can only contain one `for` loop." ,
398+ " Use `@completebasecase begin ... end` to wrap the code after the parallel" ,
399+ " loop, including the `for` loop. Got:\n " ,
400+ ex,
401+ )
402+ end
403+ foreach (_disalow_raw_for_loop, ex. args)
404+ end
0 commit comments