Skip to content

Commit 21d9f8f

Browse files
committed
Unify Symmetric/Hermitian tests
1 parent 4b22c58 commit 21d9f8f

File tree

1 file changed

+32
-60
lines changed

1 file changed

+32
-60
lines changed

test/gradcheck.jl

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ gradcheck(f, xs...) =
2525
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
2626
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
2727

28+
# utilities for using gradcheck with complex matrices
29+
_splitreim(A) = (real(A),)
30+
_splitreim(A::AbstractArray{<:Complex}) = reim(A)
31+
32+
_joinreim(A, B) = complex.(A, B)
33+
_joinreim(A) = A
34+
35+
_dropimaggrad(A) = Zygote.hook(real, A)
36+
2837
Random.seed!(0)
2938

3039
@test gradient(//, 2, 3) === (1//3, -2//9)
@@ -627,80 +636,43 @@ function _randmatseries(rng, ::typeof(atanh), T, n)
627636
return collect(tanh(Hermitian(_randmatseries(rng, tanh, T, n))))
628637
end
629638

639+
_symhermtype(::Type{<:Symmetric}) = Symmetric
640+
_symhermtype(::Type{<:Hermitian}) = Hermitian
641+
642+
function _gradtest_symherm(f, ST, A)
643+
gradtest(_splitreim(collect(A))...) do (args...)
644+
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
645+
return sum(_splitreim(B))
646+
end
647+
end
648+
630649
@testset "Hermitian/Symmetric power series functions" begin
650+
MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
651+
rng, N = MersenneTwister(123), 7
631652
@testset "$func(::RealHermSymComplexHerm)" for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt)
632653
f = eval(func)
633-
@testset "$func(::Symmetric{<:Real})" begin
634-
rng, N = MersenneTwister(123), 7
635-
A = Symmetric(_randmatseries(rng, f, Float64, N))
636-
@test gradtest(x->sum(reim(f(Symmetric(Zygote.hook(real, x))))), collect(A))
637-
y = Zygote.pullback(f, A)[1]
638-
y2 = f(A)
639-
@test y y2
640-
@testset "similar eigenvalues" begin
641-
λ, U = eigen(A)
642-
λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10
643-
A2 = U * Diagonal(λ) * U'
644-
@test gradtest(x->sum(reim(f(Symmetric(Zygote.hook(real, x))))), collect(A2))
645-
end
646-
if f (log, sqrt) # only defined for invertible matrices
647-
@testset "low rank" begin
648-
U = eigvecs(A)
649-
A2 = U * Diagonal([rand(rng), zeros(N-1)...]) * U'
650-
@test gradtest(x->sum(reim(f(Symmetric(Zygote.hook(real, x))))), collect(A2))
651-
end
652-
end
653-
end
654+
@testset "$func(::$MT)" for MT in MTs
655+
T = eltype(MT)
656+
ST = _symhermtype(MT)
657+
A = ST(_randmatseries(rng, f, T, N))
658+
λ, U = eigen(A)
659+
660+
@test _gradtest_symherm(f, ST, A)
654661

655-
@testset "$func(::Hermitian{<:Real})" begin
656-
rng, N = MersenneTwister(456), 7
657-
A = Hermitian(_randmatseries(rng, f, Float64, N))
658-
@test gradtest(x->sum(reim(f(Hermitian(Zygote.hook(real, x))))), collect(A))
659662
y = Zygote.pullback(f, A)[1]
660663
y2 = f(A)
661664
@test y y2
665+
662666
@testset "similar eigenvalues" begin
663-
λ, U = eigen(A)
664667
λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10
665668
A2 = U * Diagonal(λ) * U'
666-
@test gradtest(x->sum(reim(f(Hermitian(Zygote.hook(real, x))))), collect(A2))
669+
@test _gradtest_symherm(f, ST, A2)
667670
end
668-
if f (log, sqrt) # only defined for invertible matrices
669-
@testset "low rank" begin
670-
U = eigvecs(A)
671-
A2 = U * Diagonal([rand(rng), zeros(N-1)...]) * U'
672-
@test gradtest(x->sum(reim(f(Hermitian(Zygote.hook(real, x))))), collect(A2))
673-
end
674-
end
675-
end
676671

677-
@testset "$func(::Hermitian{<:Complex})" begin
678-
rng, N = MersenneTwister(789), 7
679-
A = Hermitian(_randmatseries(rng, f, Complex{Float64}, N))
680-
@test gradtest(reim(collect(A))...) do a,b
681-
B = f(Hermitian(complex.(a, b)))
682-
return real.(B) .+ 2 .* imag.(B)
683-
end
684-
y = Zygote.pullback(f, A)[1]
685-
y2 = f(A)
686-
@test y y2
687-
@testset "similar eigenvalues" begin
688-
λ, U = eigen(A)
689-
λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10
690-
A2 = Hermitian(U * Diagonal(λ) * U')
691-
@test gradtest(reim(collect(A2))...) do a,b
692-
B = f(Hermitian(complex.(a, b)))
693-
return real.(B) .+ 2 .* imag.(B)
694-
end
695-
end
696672
if f (log, sqrt) # only defined for invertible matrices
697673
@testset "low rank" begin
698-
U = eigvecs(A)
699-
A2 = U * Diagonal([rand(rng), zeros(N-1)...]) * U'
700-
@test gradtest(reim(collect(A2))...) do a,b
701-
B = f(Hermitian(complex.(a, b)))
702-
return real.(B) .+ 2 .* imag.(B)
703-
end
674+
A3 = U * Diagonal([rand(rng), zeros(N-1)...]) * U'
675+
@test _gradtest_symherm(f, ST, A3)
704676
end
705677
end
706678
end

0 commit comments

Comments
 (0)