Skip to content

Commit 97b4471

Browse files
committed
Simplify hermitian eigen/eigvals tests
1 parent 21d9f8f commit 97b4471

File tree

1 file changed

+33
-83
lines changed

1 file changed

+33
-83
lines changed

test/gradcheck.jl

Lines changed: 33 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -538,114 +538,64 @@ end
538538
end
539539
end
540540

541-
@testset "eigen(::RealHermSymComplexHerm)" begin
542-
@testset "eigen(::Symmetric{<:Real})" begin
543-
rng, N = MersenneTwister(123), 7
544-
A = Symmetric(randn(rng, N, N))
545-
@test gradtest(collect(A)) do (x)
546-
d, Q = eigen(Symmetric(x))
547-
return Q * Diagonal(exp.(d)) * transpose(Q)
548-
end
549-
y = Zygote.pullback(eigen, A)[1]
550-
y2 = eigen(A)
551-
@test y.values y2.values
552-
@test y.vectors y2.vectors
553-
@testset "low rank" begin
554-
U = eigvecs(A)
555-
A2 = Symmetric(U * Diagonal([randn(rng), zeros(N-1)...]) * U')
556-
@test_broken gradtest(collect(A2)) do (x)
557-
d, Q = eigen(Symmetric(x))
558-
return Q * Diagonal(exp.(d)) * transpose(Q)
559-
end
560-
end
561-
end
541+
_symhermtype(::Type{<:Symmetric}) = Symmetric
542+
_symhermtype(::Type{<:Hermitian}) = Hermitian
562543

563-
@testset "eigen(::Hermitian{<:Real})" begin
564-
rng, N = MersenneTwister(456), 7
565-
A = Hermitian(randn(rng, N, N))
566-
@test gradtest(collect(A)) do (x)
567-
d, Q = eigen(Hermitian(x))
568-
return Q * Diagonal(exp.(d)) * transpose(Q)
569-
end
570-
y = Zygote.pullback(eigen, A)[1]
571-
y2 = eigen(A)
572-
@test y.values y2.values
573-
@test y.vectors y2.vectors
574-
@testset "low rank" begin
575-
U = eigvecs(A)
576-
A2 = Hermitian(U * Diagonal([randn(rng), zeros(N-1)...]) * U')
577-
@test_broken gradtest(collect(A2)) do (x)
578-
d, Q = eigen(Hermitian(x))
579-
return Q * Diagonal(exp.(d)) * transpose(Q)
580-
end
581-
end
544+
function _gradtest_symherm(f, ST, A)
545+
gradtest(_splitreim(collect(A))...) do (args...)
546+
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
547+
return sum(_splitreim(B))
582548
end
549+
end
583550

584-
@testset "eigen(::Hermitian{<:Complex})" begin
585-
rng, N = MersenneTwister(789), 7
586-
A = Hermitian(randn(rng, ComplexF64, N, N))
587-
@test gradtest(reim(collect(A))...) do a,b
588-
d, U = eigen(Hermitian(complex.(a, b)))
589-
X = U * Diagonal(exp.(d)) * U'
590-
return real.(X) .+ imag.(X)
551+
@testset "eigen(::RealHermSymComplexHerm)" begin
552+
MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
553+
rng, N = MersenneTwister(123), 7
554+
@testset "eigen(::$MT)" for MT in MTs
555+
T = eltype(MT)
556+
ST = _symhermtype(MT)
557+
558+
A = ST(randn(rng, T, N, N))
559+
U = eigvecs(A)
560+
561+
@test _gradtest_symherm(ST, A) do (A)
562+
d, U = eigen(A)
563+
return U * Diagonal(exp.(d)) * U'
591564
end
565+
592566
y = Zygote.pullback(eigen, A)[1]
593567
y2 = eigen(A)
594568
@test y.values y2.values
595569
@test y.vectors y2.vectors
570+
596571
@testset "low rank" begin
597-
U = eigvecs(A)
598-
A2 = Hermitian(U * Diagonal([randn(rng), zeros(N-1)...]) * U')
599-
@test_broken gradtest(reim(collect(A2))...) do a,b
600-
d, U = eigen(Hermitian(complex.(a, b)))
601-
X = U * Diagonal(exp.(d)) * U'
602-
return real.(X) .+ imag.(X)
572+
A2 = Symmetric(U * Diagonal([randn(rng), zeros(N-1)...]) * U')
573+
@test_broken _gradtest_symherm(ST, A2) do (A)
574+
d, U = eigen(A)
575+
return U * Diagonal(exp.(d)) * U'
603576
end
604577
end
605578
end
606579
end
607580

608581
@testset "eigvals(::RealHermSymComplexHerm)" begin
609-
@testset "eigvals(::Symmetric{<:Real})" begin
610-
rng, N = MersenneTwister(123), 7
611-
A = Symmetric(randn(rng, N, N))
612-
@test gradtest(x->eigvals(Symmetric(x)), collect(A))
613-
@test Zygote.pullback(eigvals, A)[1] eigvals(A)
614-
end
582+
MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
583+
rng, N = MersenneTwister(123), 7
584+
@testset "eigvals(::$MT)" for MT in MTs
585+
T = eltype(MT)
586+
ST = _symhermtype(MT)
615587

616-
@testset "eigvals(::Hermitian{<:Real})" begin
617-
rng, N = MersenneTwister(456), 7
618-
A = Hermitian(randn(rng, N, N))
619-
@test gradtest(x->eigvals(Hermitian(x)), collect(A))
588+
A = ST(randn(rng, T, N, N))
589+
@test _gradtest_symherm(A ->eigvals(A), ST, A)
620590
@test Zygote.pullback(eigvals, A)[1] eigvals(A)
621591
end
622-
623-
@testset "eigvals(::Hermitian{<:Complex})" begin
624-
rng, N = MersenneTwister(789), 7
625-
A, B = randn(rng, N, N), randn(rng, N, N)
626-
@test gradtest(A, B) do a,b
627-
c = Hermitian(complex.(a, b))
628-
return eigvals(c)
629-
end
630-
@test Zygote.pullback(eigvals, Hermitian(A))[1] eigvals(Hermitian(A))
631-
end
632592
end
633593

634594
_randmatseries(rng, f, T, n) = rand(rng, T, n, n)
635595
function _randmatseries(rng, ::typeof(atanh), T, n)
636596
return collect(tanh(Hermitian(_randmatseries(rng, tanh, T, n))))
637597
end
638598

639-
_symhermtype(::Type{<:Symmetric}) = Symmetric
640-
_symhermtype(::Type{<:Hermitian}) = Hermitian
641-
642-
function _gradtest_symherm(f, ST, A)
643-
gradtest(_splitreim(collect(A))...) do (args...)
644-
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
645-
return sum(_splitreim(B))
646-
end
647-
end
648-
649599
@testset "Hermitian/Symmetric power series functions" begin
650600
MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
651601
rng, N = MersenneTwister(123), 7

0 commit comments

Comments
 (0)