@@ -25,6 +25,15 @@ gradcheck(f, xs...) =
2525gradtest (f, xs:: AbstractArray... ) = gradcheck ((xs... ) -> sum (sin .(f (xs... ))), xs... )
2626gradtest (f, dims... ) = gradtest (f, rand .(Float64, dims)... )
2727
28+ # utilities for using gradcheck with complex matrices
29+ _splitreim (A) = (real (A),)
30+ _splitreim (A:: AbstractArray{<:Complex} ) = reim (A)
31+
32+ _joinreim (A, B) = complex .(A, B)
33+ _joinreim (A) = A
34+
35+ _dropimaggrad (A) = Zygote. hook (real, A)
36+
2837Random. seed! (0 )
2938
3039@test gradient (// , 2 , 3 ) === (1 // 3 , - 2 // 9 )
@@ -627,80 +636,43 @@ function _randmatseries(rng, ::typeof(atanh), T, n)
627636 return collect (tanh (Hermitian (_randmatseries (rng, tanh, T, n))))
628637end
629638
639+ _symhermtype (:: Type{<:Symmetric} ) = Symmetric
640+ _symhermtype (:: Type{<:Hermitian} ) = Hermitian
641+
642+ function _gradtest_symherm (f, ST, A)
643+ gradtest (_splitreim (collect (A))... ) do (args... )
644+ B = f (ST (_joinreim (_dropimaggrad .(args)... )))
645+ return sum (_splitreim (B))
646+ end
647+ end
648+
630649@testset " Hermitian/Symmetric power series functions" begin
650+ MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
651+ rng, N = MersenneTwister (123 ), 7
631652 @testset " $func (::RealHermSymComplexHerm)" for func in (:exp , :log , :cos , :sin , :tan , :cosh , :sinh , :tanh , :acos , :asin , :atan , :acosh , :asinh , :atanh , :sqrt )
632653 f = eval (func)
633- @testset " $func (::Symmetric{<:Real})" begin
634- rng, N = MersenneTwister (123 ), 7
635- A = Symmetric (_randmatseries (rng, f, Float64, N))
636- @test gradtest (x-> sum (reim (f (Symmetric (Zygote. hook (real, x))))), collect (A))
637- y = Zygote. pullback (f, A)[1 ]
638- y2 = f (A)
639- @test y ≈ y2
640- @testset " similar eigenvalues" begin
641- λ, U = eigen (A)
642- λ[1 ] = λ[3 ] + sqrt (eps (eltype (λ))) / 10
643- A2 = U * Diagonal (λ) * U'
644- @test gradtest (x-> sum (reim (f (Symmetric (Zygote. hook (real, x))))), collect (A2))
645- end
646- if f ∉ (log, sqrt) # only defined for invertible matrices
647- @testset " low rank" begin
648- U = eigvecs (A)
649- A2 = U * Diagonal ([rand (rng), zeros (N- 1 )... ]) * U'
650- @test gradtest (x-> sum (reim (f (Symmetric (Zygote. hook (real, x))))), collect (A2))
651- end
652- end
653- end
654+ @testset " $func (::$MT )" for MT in MTs
655+ T = eltype (MT)
656+ ST = _symhermtype (MT)
657+ A = ST (_randmatseries (rng, f, T, N))
658+ λ, U = eigen (A)
659+
660+ @test _gradtest_symherm (f, ST, A)
654661
655- @testset " $func (::Hermitian{<:Real})" begin
656- rng, N = MersenneTwister (456 ), 7
657- A = Hermitian (_randmatseries (rng, f, Float64, N))
658- @test gradtest (x-> sum (reim (f (Hermitian (Zygote. hook (real, x))))), collect (A))
659662 y = Zygote. pullback (f, A)[1 ]
660663 y2 = f (A)
661664 @test y ≈ y2
665+
662666 @testset " similar eigenvalues" begin
663- λ, U = eigen (A)
664667 λ[1 ] = λ[3 ] + sqrt (eps (eltype (λ))) / 10
665668 A2 = U * Diagonal (λ) * U'
666- @test gradtest (x -> sum ( reim ( f ( Hermitian (Zygote . hook (real, x))))), collect (A2) )
669+ @test _gradtest_symherm (f, ST, A2 )
667670 end
668- if f ∉ (log, sqrt) # only defined for invertible matrices
669- @testset " low rank" begin
670- U = eigvecs (A)
671- A2 = U * Diagonal ([rand (rng), zeros (N- 1 )... ]) * U'
672- @test gradtest (x-> sum (reim (f (Hermitian (Zygote. hook (real, x))))), collect (A2))
673- end
674- end
675- end
676671
677- @testset " $func (::Hermitian{<:Complex})" begin
678- rng, N = MersenneTwister (789 ), 7
679- A = Hermitian (_randmatseries (rng, f, Complex{Float64}, N))
680- @test gradtest (reim (collect (A))... ) do a,b
681- B = f (Hermitian (complex .(a, b)))
682- return real .(B) .+ 2 .* imag .(B)
683- end
684- y = Zygote. pullback (f, A)[1 ]
685- y2 = f (A)
686- @test y ≈ y2
687- @testset " similar eigenvalues" begin
688- λ, U = eigen (A)
689- λ[1 ] = λ[3 ] + sqrt (eps (eltype (λ))) / 10
690- A2 = Hermitian (U * Diagonal (λ) * U' )
691- @test gradtest (reim (collect (A2))... ) do a,b
692- B = f (Hermitian (complex .(a, b)))
693- return real .(B) .+ 2 .* imag .(B)
694- end
695- end
696672 if f ∉ (log, sqrt) # only defined for invertible matrices
697673 @testset " low rank" begin
698- U = eigvecs (A)
699- A2 = U * Diagonal ([rand (rng), zeros (N- 1 )... ]) * U'
700- @test gradtest (reim (collect (A2))... ) do a,b
701- B = f (Hermitian (complex .(a, b)))
702- return real .(B) .+ 2 .* imag .(B)
703- end
674+ A3 = U * Diagonal ([rand (rng), zeros (N- 1 )... ]) * U'
675+ @test _gradtest_symherm (f, ST, A3)
704676 end
705677 end
706678 end
0 commit comments