@@ -538,114 +538,64 @@ end
538538 end
539539end
540540
541- @testset " eigen(::RealHermSymComplexHerm)" begin
542- @testset " eigen(::Symmetric{<:Real})" begin
543- rng, N = MersenneTwister (123 ), 7
544- A = Symmetric (randn (rng, N, N))
545- @test gradtest (collect (A)) do (x)
546- d, Q = eigen (Symmetric (x))
547- return Q * Diagonal (exp .(d)) * transpose (Q)
548- end
549- y = Zygote. pullback (eigen, A)[1 ]
550- y2 = eigen (A)
551- @test y. values ≈ y2. values
552- @test y. vectors ≈ y2. vectors
553- @testset " low rank" begin
554- U = eigvecs (A)
555- A2 = Symmetric (U * Diagonal ([randn (rng), zeros (N- 1 )... ]) * U' )
556- @test_broken gradtest (collect (A2)) do (x)
557- d, Q = eigen (Symmetric (x))
558- return Q * Diagonal (exp .(d)) * transpose (Q)
559- end
560- end
561- end
541+ _symhermtype (:: Type{<:Symmetric} ) = Symmetric
542+ _symhermtype (:: Type{<:Hermitian} ) = Hermitian
562543
563- @testset " eigen(::Hermitian{<:Real})" begin
564- rng, N = MersenneTwister (456 ), 7
565- A = Hermitian (randn (rng, N, N))
566- @test gradtest (collect (A)) do (x)
567- d, Q = eigen (Hermitian (x))
568- return Q * Diagonal (exp .(d)) * transpose (Q)
569- end
570- y = Zygote. pullback (eigen, A)[1 ]
571- y2 = eigen (A)
572- @test y. values ≈ y2. values
573- @test y. vectors ≈ y2. vectors
574- @testset " low rank" begin
575- U = eigvecs (A)
576- A2 = Hermitian (U * Diagonal ([randn (rng), zeros (N- 1 )... ]) * U' )
577- @test_broken gradtest (collect (A2)) do (x)
578- d, Q = eigen (Hermitian (x))
579- return Q * Diagonal (exp .(d)) * transpose (Q)
580- end
581- end
544+ function _gradtest_symherm (f, ST, A)
545+ gradtest (_splitreim (collect (A))... ) do (args... )
546+ B = f (ST (_joinreim (_dropimaggrad .(args)... )))
547+ return sum (_splitreim (B))
582548 end
549+ end
583550
584- @testset " eigen(::Hermitian{<:Complex})" begin
585- rng, N = MersenneTwister (789 ), 7
586- A = Hermitian (randn (rng, ComplexF64, N, N))
587- @test gradtest (reim (collect (A))... ) do a,b
588- d, U = eigen (Hermitian (complex .(a, b)))
589- X = U * Diagonal (exp .(d)) * U'
590- return real .(X) .+ imag .(X)
551+ @testset " eigen(::RealHermSymComplexHerm)" begin
552+ MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
553+ rng, N = MersenneTwister (123 ), 7
554+ @testset " eigen(::$MT )" for MT in MTs
555+ T = eltype (MT)
556+ ST = _symhermtype (MT)
557+
558+ A = ST (randn (rng, T, N, N))
559+ U = eigvecs (A)
560+
561+ @test _gradtest_symherm (ST, A) do (A)
562+ d, U = eigen (A)
563+ return U * Diagonal (exp .(d)) * U'
591564 end
565+
592566 y = Zygote. pullback (eigen, A)[1 ]
593567 y2 = eigen (A)
594568 @test y. values ≈ y2. values
595569 @test y. vectors ≈ y2. vectors
570+
596571 @testset " low rank" begin
597- U = eigvecs (A)
598- A2 = Hermitian (U * Diagonal ([randn (rng), zeros (N- 1 )... ]) * U' )
599- @test_broken gradtest (reim (collect (A2))... ) do a,b
600- d, U = eigen (Hermitian (complex .(a, b)))
601- X = U * Diagonal (exp .(d)) * U'
602- return real .(X) .+ imag .(X)
572+ A2 = Symmetric (U * Diagonal ([randn (rng), zeros (N- 1 )... ]) * U' )
573+ @test_broken _gradtest_symherm (ST, A2) do (A)
574+ d, U = eigen (A)
575+ return U * Diagonal (exp .(d)) * U'
603576 end
604577 end
605578 end
606579end
607580
608581@testset " eigvals(::RealHermSymComplexHerm)" begin
609- @testset " eigvals(::Symmetric{<:Real})" begin
610- rng, N = MersenneTwister (123 ), 7
611- A = Symmetric (randn (rng, N, N))
612- @test gradtest (x-> eigvals (Symmetric (x)), collect (A))
613- @test Zygote. pullback (eigvals, A)[1 ] ≈ eigvals (A)
614- end
582+ MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
583+ rng, N = MersenneTwister (123 ), 7
584+ @testset " eigvals(::$MT )" for MT in MTs
585+ T = eltype (MT)
586+ ST = _symhermtype (MT)
615587
616- @testset " eigvals(::Hermitian{<:Real})" begin
617- rng, N = MersenneTwister (456 ), 7
618- A = Hermitian (randn (rng, N, N))
619- @test gradtest (x-> eigvals (Hermitian (x)), collect (A))
588+ A = ST (randn (rng, T, N, N))
589+ @test _gradtest_symherm (A -> eigvals (A), ST, A)
620590 @test Zygote. pullback (eigvals, A)[1 ] ≈ eigvals (A)
621591 end
622-
623- @testset " eigvals(::Hermitian{<:Complex})" begin
624- rng, N = MersenneTwister (789 ), 7
625- A, B = randn (rng, N, N), randn (rng, N, N)
626- @test gradtest (A, B) do a,b
627- c = Hermitian (complex .(a, b))
628- return eigvals (c)
629- end
630- @test Zygote. pullback (eigvals, Hermitian (A))[1 ] ≈ eigvals (Hermitian (A))
631- end
632592end
633593
634594_randmatseries (rng, f, T, n) = rand (rng, T, n, n)
635595function _randmatseries (rng, :: typeof (atanh), T, n)
636596 return collect (tanh (Hermitian (_randmatseries (rng, tanh, T, n))))
637597end
638598
639- _symhermtype (:: Type{<:Symmetric} ) = Symmetric
640- _symhermtype (:: Type{<:Hermitian} ) = Hermitian
641-
642- function _gradtest_symherm (f, ST, A)
643- gradtest (_splitreim (collect (A))... ) do (args... )
644- B = f (ST (_joinreim (_dropimaggrad .(args)... )))
645- return sum (_splitreim (B))
646- end
647- end
648-
649599@testset " Hermitian/Symmetric power series functions" begin
650600 MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
651601 rng, N = MersenneTwister (123 ), 7
0 commit comments