Skip to content

Commit 8b6985a

Browse files
authored
[Containers] fix support for sum(::DenseAxisArray; init) (#4085)
1 parent 5cdd88b commit 8b6985a

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

src/Containers/DenseAxisArray.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -856,21 +856,22 @@ end
856856

857857
function Base.sum(
858858
f::F,
859-
x::Union{DenseAxisArray,DenseAxisArrayView};
859+
x::Union{DenseAxisArray{T},DenseAxisArrayView{T}};
860860
dims = Colon(),
861-
) where {F<:Function}
862-
if dims == Colon()
863-
return sum(f(xi) for xi in x)
861+
init = zero(T),
862+
) where {F<:Function,T}
863+
if dims != Colon()
864+
return error(
865+
"`sum(x::DenseAxisArray; dims)` is not supported. Convert the array " *
866+
"to an `Array` using `sum(Array(x); dims=$dims)`, or use an explicit " *
867+
"for-loop summation instead.",
868+
)
864869
end
865-
return error(
866-
"`sum(x::DenseAxisArray; dims)` is not supported. Convert the array " *
867-
"to an `Array` using `sum(Array(x); dims=$dims)`, or use an explicit " *
868-
"for-loop summation instead.",
869-
)
870+
return sum(f(xi) for xi in x; init)
870871
end
871872

872-
function Base.sum(x::Union{DenseAxisArray,DenseAxisArrayView}; dims = Colon())
873-
return sum(identity, x; dims = dims)
873+
function Base.sum(x::Union{DenseAxisArray,DenseAxisArrayView}; kwargs...)
874+
return sum(identity, x; kwargs...)
874875
end
875876

876877
function Base.promote_shape(a::DenseAxisArray, b::DenseAxisArray)

test/Containers/test_DenseAxisArray.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,4 +990,17 @@ function test_issue_4053()
990990
return
991991
end
992992

993+
function test_sum_init()
994+
x = Containers.@container([i in Int[]], i)
995+
@test sum(x) == 0
996+
@test sum(x; init = 1) == 1
997+
y = Containers.@container([i in BigInt[]], i)
998+
y_1 = sum(y)
999+
@test y_1 == 0
1000+
@test y_1 isa BigInt
1001+
y_2 = sum(y; init = 0)
1002+
@test y_2 === 0
1003+
return
1004+
end
1005+
9931006
end # module

0 commit comments

Comments
 (0)