diff --git a/src/reduce.jl b/src/reduce.jl index 9a909ee..fbd4c9a 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -22,7 +22,8 @@ reduce_op(op::AddSubMul) = add_sub_op(op) reduce_op(::typeof(add_dot)) = + -neutral_element(::typeof(+), T::Type) = zero(T) +neutral_element(::typeof(+), ::Type{T}) where {T} = zero(T) +neutral_element(::typeof(+), ::Type{T}) where {T<:AbstractArray} = Zero() map_op(::AddSubMul) = * diff --git a/test/interface.jl b/test/interface.jl index cfe56ab..ac5e120 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -215,3 +215,21 @@ end @test ret == reshape([3.0], 1, 1) @test y == reshape([1.0], 1, 1, 1) end + +@testset "issue_318_neutral_element" begin + a = rand(3) + A = [rand(2, 2) for _ in 1:3] + @test_throws DimensionMismatch MA.operate(LinearAlgebra.dot, a, A) + y = a' * A + @test isapprox(MA.fused_map_reduce(MA.add_mul, a', A), y) + z = MA.operate(LinearAlgebra.dot, Int[], Int[]) + @test iszero(z) && z isa Int + z = MA.operate(LinearAlgebra.dot, BigInt[], Int[]) + @test iszero(z) && z isa BigInt + z = MA.operate(LinearAlgebra.dot, Int[], Float64[]) + @test iszero(z) && z isa Float64 + z = MA.operate(LinearAlgebra.dot, Matrix{Int}[], Matrix{Float64}[]) + @test iszero(z) && z isa Float64 + @test MA.fused_map_reduce(MA.add_mul, Matrix{Int}[], Float64[]) isa MA.Zero + @test MA.fused_map_reduce(MA.add_mul, Float64[], Matrix{Int}[]) isa MA.Zero +end