Skip to content

Commit 09ce663

Browse files
committed
Fix rewrite generic of sum(::Symbol; init)
1 parent ffcfd9a commit 09ce663

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/rewrite_generic.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,18 @@ function _rewrite_generic(stack::Expr, expr::Expr)
125125
# The summation has keyword arguments. We can deal with `init`, but
126126
# not any of the others.
127127
p = expr.args[2]
128-
if length(p.args) == 1 && _is_kwarg(p.args[1], :init)
128+
is_init = length(p.args) == 1 && _is_kwarg(p.args[1], :init)
129+
if is_init && expr.args[3] isa Expr
129130
# sum(iter ; init) form!
131+
# We rewrite only if `iter` is an Expr; if it's just a Symbol,
132+
# we don't enter this branch.
130133
root = gensym()
131134
init, _ = _rewrite_generic(stack, p.args[1].args[2])
132135
push!(stack.args, :($root = $init))
133136
return _rewrite_generic_generator(stack, :+, expr.args[3], root)
134-
else
135-
# We don't know how to deal with this
136-
return esc(expr), false
137137
end
138+
# We don't know how to deal with this
139+
return esc(expr), false
138140
else
139141
# Summations use :+ as the reduction operator.
140142
init_expr = expr.args[2].args[end]

test/rewrite_generic.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,15 @@ function test_allocations_rewrite_mult()
509509
return
510510
end
511511

512+
function test_rewrite_init_symbol()
513+
x = Int[]
514+
y = MA.@rewrite(sum(x; init = 0), move_factors_into_sums = false)
515+
@test y == 0
516+
y = MA.@rewrite(sum(x, init = 0), move_factors_into_sums = false)
517+
@test y == 0
518+
return
519+
end
520+
512521
end # module
513522

514523
TestRewriteGeneric.runtests()

0 commit comments

Comments
 (0)