416416 return H, back
417417end
418418
# Adjoint for converting a Hermitian/Symmetric wrapper `A` to an `Array` type `R`.
# The type argument receives no gradient (`nothing`); the incoming cotangent `Δ` is
# converted to `S`, the second type parameter of the wrapper.
@adjoint function convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array}
  dense = convert(R, A)
  convert_pullback(Δ) = (nothing, convert(S, Δ))
  return dense, convert_pullback
end
# Adjoint for materialising a Hermitian/Symmetric wrapper `A` with `Matrix`.
# The cotangent `Δ` is converted to `S`, the second type parameter of the wrapper.
@adjoint function Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S}
  dense = Matrix(A)
  matrix_pullback(Δ) = (convert(S, Δ),)
  return dense, matrix_pullback
end
423+
419424@adjoint function cholesky (Σ:: Real )
420425 C = cholesky (Σ)
421426 return C, Δ:: NamedTuple -> (Δ. factors[1 , 1 ] / (2 * C. U[1 , 1 ]),)
@@ -451,19 +456,39 @@ end
451456 end
452457end
453458
# Entry `(i, j)` of the matrix of pairwise difference quotients.
#
# `x` holds the sample points, `fx` the function values at those points, `dfx` the first
# derivatives and (optionally) `d²fx` the second derivatives. The argument `f` is not
# used in the body.
#
# * `i == j`: returns `dfx[i]`.
# * `x[i]` and `x[j]` nearly equal: the quotient is replaced by an approximation, namely
#   the mean of the two first derivatives when `d²fx === nothing` (threshold
#   `sqrt(eps(T))`), otherwise `dfx[i] - Δx / 2 * d²fx[i]` (threshold `eps(T)^(1 / 3)`),
#   where `T = real(eltype(x))`.
# * otherwise: returns `(fx[i] - fx[j]) / (x[i] - x[j])`.
Base.@propagate_inbounds function _pairdiffquot(f, i, j, x, fx, dfx, d²fx = nothing)
  if i == j
    return dfx[i]
  end
  gap = x[i] - x[j]
  R = real(eltype(x))
  if d²fx === nothing
    if abs(gap) ≤ sqrt(eps(R))
      return (dfx[i] + dfx[j]) / 2
    end
  elseif abs(gap) ≤ eps(R)^(1 / 3)
    return dfx[i] - gap / 2 * d²fx[i]
  end
  return (fx[i] - fx[j]) / gap
end
472+
# Build the `n × n` matrix whose `(i, j)` entry is
# `_pairdiffquot(f, i, j, x, fx, dfx, d²fx)`, by broadcasting over a column of indices
# and its adjoint (a row of indices).
Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = nothing)
  idx = Base.OneTo(n)
  quotient(i, j) = _pairdiffquot(f, i, j, x, fx, dfx, d²fx)
  return quotient.(idx, idx')
end
477+
# Adjoint based on the Theano implementation, which uses the differential as described
# in Brančík, "Matlab programs for matrix exponential function derivative evaluation".
#
# The pullback eigendecomposes `A`, builds the matrix of pairwise difference quotients of
# `exp` at the eigenvalues (with `exp` of the eigenvalues supplied as function values and
# as first and second derivatives), and combines it with the cotangent `F̄` in the
# eigenbasis.
@adjoint function exp(A::AbstractMatrix)
  expA = exp(A)
  function exp_pullback(F̄)
    decomp = eigen(A)
    vals = decomp.values
    expvals = exp.(vals)
    quotients = _pairdiffquotmat(exp, size(A, 1), vals, expvals, expvals, expvals)
    vecs = decomp.vectors
    vecsfact = factorize(vecs)
    Ā = (vecs * ((vecsfact \ F̄' * vecs) .* quotients) / vecsfact)'
    return (Ā,)
  end
  return expA, exp_pullback
end
491+
467492@adjoint function LinearAlgebra. eigen (A:: LinearAlgebra.RealHermSymComplexHerm )
468493 dU = eigen (A)
469494 return dU, function (Δ)
@@ -489,6 +514,143 @@ end
489514 return d, d̄ -> (U * Diagonal (d̄) * U' ,)
490515end
491516
517+
518+ # Hermitian/Symmetric matrix functions that can be written as power series
# Replace every diagonal entry of `A` by its real part, in place, and return `A`.
# Arrays with real element type are returned untouched (and without a squareness check);
# any other argument must be square, as enforced by `LinearAlgebra.checksquare`.
_realifydiag!(A::AbstractArray{<:Real}) = A
function _realifydiag!(A)
  for k in 1:LinearAlgebra.checksquare(A)
    @inbounds A[k, k] = real(A[k, k])
  end
  return A
end
# Adjoint of `_realifydiag!`: the pullback applies `_realifydiag!` to the cotangent `Δ`
# itself (so `Δ` is modified in place when it is not a real array).
@adjoint function _realifydiag!(A)
  realified = _realifydiag!(A)
  realifydiag_pullback(Δ) = (_realifydiag!(Δ),)
  return realified, realifydiag_pullback
end
528+
# `_hasrealdomain(f, x)`: whether every element of `x` passes the range check attached
# to `f`. Functions without a dedicated method are treated as having no restriction.
_hasrealdomain(f, x) = true
# `acos`, `asin`: all elements in [-1, 1].
function _hasrealdomain(::Union{typeof(acos),typeof(asin)}, x)
  return all(v -> v ≥ -1 && v ≤ 1, x)
end
# `acosh`: all elements at least 1.
function _hasrealdomain(::typeof(acosh), x)
  return all(v -> 1 ≤ v, x)
end
# `log`, `sqrt`, `^`: all elements non-negative.
function _hasrealdomain(::Union{typeof(log),typeof(sqrt),typeof(^)}, x)
  return all(v -> 0 ≤ v, x)
end
533+
# Return the eigenvalues `λ` unchanged when `_hasrealdomain(f, λ)` holds; otherwise
# return them converted elementwise with `complex`.
function _process_series_eigvals(f, λ)
  if _hasrealdomain(f, λ)
    return λ
  end
  return complex.(λ)
end
535+
# `_process_series_matrix(f, fA, A, fλ)`: choose the wrapper for the dense result `fA`
# from the function `f`, the type of the input `A`, and the element type of `fλ`.

# Fallback: hand back the dense result as is.
_process_series_matrix(f, fA, A, fλ) = fA

# Real Hermitian/Symmetric input: wrap as `Symmetric`.
function _process_series_matrix(f, fA, ::LinearAlgebra.HermOrSym{<:Real}, fλ)
  return Symmetric(fA)
end

# Complex Hermitian input with real `fλ`: make the diagonal real in place, then wrap as
# `Hermitian`.
function _process_series_matrix(f, fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Real})
  return Hermitian(_realifydiag!(fA))
end

# `^` on a real Hermitian input: wrap as `Hermitian` ...
function _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, fλ)
  return Hermitian(fA)
end

# ... unless `fλ` is complex, in which case the dense result is returned.
function _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real},
                                ::AbstractVector{<:Complex})
  return fA
end

# `^` on a complex Hermitian input with complex `fλ`: dense result.
function _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex},
                                ::AbstractVector{<:Complex})
  return fA
end
543+
# Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives,
# and function to pull back adjoints to args.
#
# Returns a 4-tuple:
#   1. `f` broadcast over the eigenvalues (after `_process_series_eigvals`) and `args`;
#   2. a thunk that pulls a vector of ones back through the broadcast and returns the
#      component for the eigenvalues;
#   3. a thunk returning `nothing` (second derivative not implemented);
#   4. a function mapping a cotangent of the function values to the cotangents of
#      `args` (an empty tuple when there are no `args`).
function _pullback_series_func_scalar(f, λ, args...)
  domainλ = _process_series_eigvals(f, λ)
  fλ, fback = Zygote.pullback(domainλ, args...) do x, rest...
    return f.(x, rest...)
  end
  n = length(λ)
  firstderiv() = fback(ones(n))[1]
  secondderiv() = nothing # TODO: add 2nd deriv
  argsback = if isempty(args)
    _ -> ()
  else
    f̄λ -> tail(fback(f̄λ))
  end
  return (fλ, firstderiv, secondderiv, argsback)
end
555+
# Specialisation for `λ .^ p`, with closed-form derivatives.
#
# When `p` is integer-valued the power is taken on `λ` itself with `Integer(p)` as the
# exponent; otherwise it is taken on the output of `_process_series_eigvals` with `p`.
# Returns the powers, thunks for the conjugated first and second derivatives with
# respect to the base, and a function giving the cotangent of `p` as
# `dot(fλ .* log.(compλ), f̄λ)`.
function _pullback_series_func_scalar(f::typeof(^), λ, p)
  compλ = _process_series_eigvals(f, λ)
  exponent, base = if isinteger(p)
    (Integer(p), λ)
  else
    (p, compλ)
  end
  fλ = base .^ exponent
  firstderiv() = conj.(exponent .* base .^ (exponent - 1))
  secondderiv() = conj.((exponent * (exponent - 1)) .* base .^ (exponent - 2))
  exponentback(f̄λ) = (dot(fλ .* log.(compλ), f̄λ),)
  return (fλ, firstderiv, secondderiv, exponentback)
end
565+
# Specialisation for `exp`: the function values double as first and second derivatives,
# so both thunks return the same array `exp.(λ)`. There are no extra arguments, hence
# the last element always yields an empty tuple.
function _pullback_series_func_scalar(f::typeof(exp), λ)
  expλ = exp.(λ)
  derivative() = expλ
  noargsback(_) = ()
  return (expλ, derivative, derivative, noargsback)
end
570+
# Call `f` on `A` and any further `args`; exists so that an adjoint can be attached to
# the application itself.
function _apply_series_func(f, A, args...)
  return f(A, args...)
end
572+
# Adjoint of `_apply_series_func` via the eigendecomposition of `A`.
#
# Forward: `f` is applied to the eigenvalues (through `_pullback_series_func_scalar`),
# the result is assembled as `U * Diagonal(fλ) * U'` and wrapped by
# `_process_series_matrix`.
# Pullback: the cotangent is rotated into the eigenbasis, multiplied elementwise by the
# matrix of pairwise difference quotients, and rotated back. `f` itself receives
# `nothing`; cotangents of `args` come from the diagonal of the rotated cotangent.
@adjoint function _apply_series_func(f, A, args...)
  n = LinearAlgebra.checksquare(A)
  λ, U = eigen(A)
  fλ, dfthunk, d²fthunk, argsback = _pullback_series_func_scalar(f, λ, args...)
  fA = U * Diagonal(fλ) * U'
  Ω = _process_series_matrix(f, fA, A, fλ)
  function series_pullback(f̄A)
    f̄Λ = U' * f̄A * U
    ārgs = isempty(args) ? () : argsback(diag(f̄Λ))
    quotients = _pairdiffquotmat(f, n, λ, conj(fλ), dfthunk(), d²fthunk())
    Ā = U * (quotients .* f̄Λ) * U'
    return (nothing, Ā, ārgs...)
  end
  return Ω, series_pullback
end
589+
# Integer power of a wrapped matrix: `Symmetric` goes through `LinearAlgebra.sympow`,
# `Hermitian` through `^`.
function _hermsympow(A::Symmetric, p::Integer)
  return LinearAlgebra.sympow(A, p)
end
function _hermsympow(A::Hermitian, p::Integer)
  return A^p
end
592+
# Adjoint of `_hermsympow` for `Hermitian` input.
#
# Forward: the power is computed with `Base.power_by_squaring` (on `inv(A)` with
# exponent `-p` when `p < 0`), its diagonal is made real in place, and the result is
# wrapped as `Hermitian`.
# Pullback: the cotangent is passed through `_hermitian_back(·, 'U')` and then through
# the pullback of the power; the exponent receives `nothing`.
@adjoint function _hermsympow(A::Hermitian, p::Integer)
  B, back = if p < 0
    Zygote.pullback(X -> Base.power_by_squaring(inv(X), -p), A)
  else
    Zygote.pullback(X -> Base.power_by_squaring(X, p), A)
  end
  Ω = Hermitian(_realifydiag!(B))
  function hermsympow_pullback(Ω̄)
    B̄ = _hermitian_back(Ω̄, 'U')
    return (back(B̄)[1], nothing)
  end
  return Ω, hermsympow_pullback
end
606+
# Integer powers of Hermitian/Symmetric matrices are differentiated via `_hermsympow`.
function _pullback(cx::AContext, ::typeof(^), A::LinearAlgebra.HermOrSym{<:Real},
                   p::Integer)
  return _pullback(cx, _hermsympow, A, p)
end
function _pullback(cx::AContext, ::typeof(^), A::Symmetric{<:Complex}, p::Integer)
  return _pullback(cx, _hermsympow, A, p)
end
function _pullback(cx::AContext, ::typeof(^), A::Hermitian{<:Complex}, p::Integer)
  return _pullback(cx, _hermsympow, A, p)
end
613+
# Real powers of real Hermitian/Symmetric or complex Hermitian matrices are
# differentiated through `_apply_series_func`, with both the matrix and the exponent as
# differentiable arguments.
function _pullback(cx::AContext, f::typeof(^), A::LinearAlgebra.RealHermSymComplexHerm,
                   p::Real)
  return _pullback(cx, (X, q) -> _apply_series_func(f, X, q), A, p)
end
620+
# For each listed function, route differentiation of `f(A)` on real Hermitian/Symmetric
# or complex Hermitian matrices through `_apply_series_func`.
for fname in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh,
              :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt)
  @eval function _pullback(cx::AContext, f::typeof($fname),
                           A::LinearAlgebra.RealHermSymComplexHerm)
    return _pullback(cx, X -> _apply_series_func(f, X), A)
  end
end
630+
# Adjoint of `sincos` for real Hermitian/Symmetric or complex Hermitian matrices.
#
# Forward: `sincos` is evaluated on each eigenvalue of `A`, both results are assembled
# in the eigenbasis, and each is wrapped by `_process_series_matrix`.
# Pullback: the two cotangents are pulled back through the wrapping step, rotated into
# the eigenbasis, weighted by the pairwise difference quotient matrices of `sin` and
# `cos`, summed, and rotated back.
@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm)
  n = LinearAlgebra.checksquare(A)
  λ, U = eigen(A)
  sinbuf, cosbuf = Buffer(λ), Buffer(λ)
  for k in Base.OneTo(n)
    @inbounds sinbuf[k], cosbuf[k] = sincos(λ[k])
  end
  sinλ, cosλ = copy(sinbuf), copy(cosbuf)
  sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U'
  Ω, wrapback = Zygote.pullback(
    (s, c) -> (_process_series_matrix(sin, s, A, λ),
               _process_series_matrix(cos, c, A, λ)),
    sinA, cosA)
  function sincos_pullback(Ω̄)
    s̄inA, c̄osA = wrapback(Ω̄)
    s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U
    # Function values, first and second derivatives of sin and cos at the eigenvalues.
    sinquot = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ)
    cosquot = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ)
    Ā = U * (sinquot .* s̄inΛ .+ cosquot .* c̄osΛ) * U'
    return (Ā,)
  end
  return Ω, sincos_pullback
end
653+
492654Zygote. @adjoint function LinearAlgebra. tr (x:: AbstractMatrix )
493655 # x is a square matrix checked by tr,
494656 # so we could just use Eye(size(x, 1))
0 commit comments