diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index b29aa18007b..0b6e8ec2ff9 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -856,21 +856,22 @@ end function Base.sum( f::F, - x::Union{DenseAxisArray,DenseAxisArrayView}; + x::Union{DenseAxisArray{T},DenseAxisArrayView{T}}; dims = Colon(), -) where {F<:Function} - if dims == Colon() - return sum(f(xi) for xi in x) + init = zero(T), +) where {F<:Function,T} + if dims != Colon() + return error( + "`sum(x::DenseAxisArray; dims)` is not supported. Convert the array " * + "to an `Array` using `sum(Array(x); dims=$dims)`, or use an explicit " * + "for-loop summation instead.", + ) end - return error( - "`sum(x::DenseAxisArray; dims)` is not supported. Convert the array " * - "to an `Array` using `sum(Array(x); dims=$dims)`, or use an explicit " * - "for-loop summation instead.", - ) + return sum(f(xi) for xi in x; init) end -function Base.sum(x::Union{DenseAxisArray,DenseAxisArrayView}; dims = Colon()) - return sum(identity, x; dims = dims) +function Base.sum(x::Union{DenseAxisArray,DenseAxisArrayView}; kwargs...) + return sum(identity, x; kwargs...) end function Base.promote_shape(a::DenseAxisArray, b::DenseAxisArray) diff --git a/test/Containers/test_DenseAxisArray.jl b/test/Containers/test_DenseAxisArray.jl index 0a39418bbdb..243c7e01df0 100644 --- a/test/Containers/test_DenseAxisArray.jl +++ b/test/Containers/test_DenseAxisArray.jl @@ -990,4 +990,17 @@ function test_issue_4053() return end +function test_sum_init() + x = Containers.@container([i in Int[]], i) + @test sum(x) == 0 + @test sum(x; init = 1) == 1 + y = Containers.@container([i in BigInt[]], i) + y_1 = sum(y) + @test y_1 == 0 + @test y_1 isa BigInt + y_2 = sum(y; init = 0) + @test y_2 === 0 + return +end + end # module