Skip to content

Commit 287f704

Browse files
authored
Try #355:
2 parents 6ac5b5b + fb95ec7 commit 287f704

File tree

3 files changed

+403
-70
lines changed

3 files changed

+403
-70
lines changed

src/lib/array.jl

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ end
416416
return H, back
417417
end
418418

419+
@adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A),
420+
Δ -> (nothing, convert(S, Δ),)
421+
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
422+
Δ -> (convert(S, Δ),)
423+
419424
@adjoint function cholesky::Real)
420425
C = cholesky(Σ)
421426
return C, Δ::NamedTuple->.factors[1, 1] / (2 * C.U[1, 1]),)
@@ -451,19 +456,39 @@ end
451456
end
452457
end
453458

459+
# Matrix of pairwise difference quotients
460+
Base.@propagate_inbounds function _pairdiffquot(f, i, j, x, fx, dfx, d²fx = nothing)
461+
i == j && return dfx[i]
462+
Δx = x[i] - x[j]
463+
T = real(eltype(x))
464+
if d²fx === nothing
465+
abs(Δx) sqrt(eps(T)) && return (dfx[i] + dfx[j]) / 2
466+
else
467+
abs(Δx) eps(T)^(1/3) && return dfx[i] - Δx / 2 * d²fx[i]
468+
end
469+
Δfx = fx[i] - fx[j]
470+
return Δfx / Δx
471+
end
472+
473+
Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = nothing)
474+
Δfij = (i, j)->_pairdiffquot(f, i, j, x, fx, dfx, d²fx)
475+
return Δfij.(Base.OneTo(n), Base.OneTo(n)')
476+
end
477+
454478
# Adjoint based on the Theano implementation, which uses the differential as described
455479
# in Brančík, "Matlab programs for matrix exponential function derivative evaluation"
456480
@adjoint exp(A::AbstractMatrix) = exp(A), function(F̄)
457481
n = size(A, 1)
458482
E = eigen(A)
459483
w = E.values
460484
ew = exp.(w)
461-
X = [i==j ? ew[i] : (ew[i]-ew[j])/(w[i]-w[j]) for i in 1:n,j=1:n]
485+
X = _pairdiffquotmat(exp, n, w, ew, ew, ew)
462486
V = E.vectors
463487
VF = factorize(V)
464488
= (V * ((VF \' * V) .* X) / VF)'
465489
return (Ā,)
466490
end
491+
467492
@adjoint function LinearAlgebra.eigen(A::LinearAlgebra.RealHermSymComplexHerm)
468493
dU = eigen(A)
469494
return dU, function (Δ)
@@ -489,6 +514,143 @@ end
489514
return d, d̄ -> (U * Diagonal(d̄) * U',)
490515
end
491516

517+
518+
# Hermitian/Symmetric matrix functions that can be written as power series
519+
_realifydiag!(A::AbstractArray{<:Real}) = A
520+
function _realifydiag!(A)
521+
n = LinearAlgebra.checksquare(A)
522+
for i in 1:n
523+
@inbounds A[i,i] = real(A[i,i])
524+
end
525+
return A
526+
end
527+
@adjoint _realifydiag!(A) = _realifydiag!(A), Δ -> (_realifydiag!(Δ),)
528+
529+
_hasrealdomain(f, x) = true
530+
_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(x -> -1 x 1, x)
531+
_hasrealdomain(::typeof(acosh), x) = all(x -> x 1, x)
532+
_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(x -> x 0, x)
533+
534+
_process_series_eigvals(f, λ) = _hasrealdomain(f, λ) ? λ : complex.(λ)
535+
536+
_process_series_matrix(f, fA, A, fλ) = fA
537+
_process_series_matrix(f, fA, ::LinearAlgebra.HermOrSym{<:Real}, fλ) = Symmetric(fA)
538+
_process_series_matrix(f, fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Real}) =
539+
Hermitian(_realifydiag!(fA))
540+
_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, fλ) = Hermitian(fA)
541+
_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, ::AbstractVector{<:Complex}) = fA
542+
_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Complex}) = fA
543+
544+
# Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives,
545+
# and function to pull back adjoints to args
546+
function _pullback_series_func_scalar(f, λ, args...)
547+
compλ = _process_series_eigvals(f, λ)
548+
fλ, fback = Zygote.pullback((x,args...) -> f.(x, args...), compλ, args...)
549+
n = length(λ)
550+
return (fλ,
551+
()->fback(ones(n))[1],
552+
()->nothing, # TODO: add 2nd deriv
553+
isempty(args) ? _ -> () : f̄λ -> tail(fback(f̄λ)))
554+
end
555+
556+
function _pullback_series_func_scalar(f::typeof(^), λ, p)
557+
compλ = _process_series_eigvals(f, λ)
558+
r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ)
559+
= powλ .^ r
560+
return (fλ,
561+
()->conj.(r .* powλ .^ (r - 1)),
562+
()->conj.((r * (r - 1)) .* powλ .^ (r - 2)),
563+
f̄λ -> (dot(fλ .* log.(compλ), f̄λ),))
564+
end
565+
566+
function _pullback_series_func_scalar(f::typeof(exp), λ)
567+
expλ = exp.(λ)
568+
return expλ, ()->expλ, ()->expλ, _ -> ()
569+
end
570+
571+
_apply_series_func(f, A, args...) = f(A, args...)
572+
573+
@adjoint function _apply_series_func(f, A, args...)
574+
hasargs = !isempty(args)
575+
n = LinearAlgebra.checksquare(A)
576+
λ, U = eigen(A)
577+
fλ, dfthunk, d²fthunk, argsback = _pullback_series_func_scalar(f, λ, args...)
578+
= Diagonal(fλ)
579+
fA = U ** U'
580+
Ω = _process_series_matrix(f, fA, A, fλ)
581+
return Ω, function (f̄A)
582+
f̄Λ = U' * f̄A * U
583+
ārgs = hasargs ? argsback(diag(f̄Λ)) : ()
584+
P = _pairdiffquotmat(f, n, λ, conj(fλ), dfthunk(), d²fthunk())
585+
= U * (P .* f̄Λ) * U'
586+
return (nothing, Ā, ārgs...)
587+
end
588+
end
589+
590+
_hermsympow(A::Symmetric, p::Integer) = LinearAlgebra.sympow(A, p)
591+
_hermsympow(A::Hermitian, p::Integer) = A^p
592+
593+
@adjoint function _hermsympow(A::Hermitian, p::Integer)
594+
if p < 0
595+
B, back = Zygote.pullback(A->Base.power_by_squaring(inv(A), -p), A)
596+
else
597+
B, back = Zygote.pullback(A->Base.power_by_squaring(A, p), A)
598+
end
599+
Ω = Hermitian(_realifydiag!(B))
600+
return Ω, function (Ω̄)
601+
= _hermitian_back(Ω̄, 'U')
602+
= back(B̄)[1]
603+
return (Ā, nothing)
604+
end
605+
end
606+
607+
_pullback(cx::AContext, ::typeof(^), A::LinearAlgebra.HermOrSym{<:Real}, p::Integer) =
608+
_pullback(cx, _hermsympow, A, p)
609+
_pullback(cx::AContext, ::typeof(^), A::Symmetric{<:Complex}, p::Integer) =
610+
_pullback(cx, _hermsympow, A, p)
611+
_pullback(cx::AContext, ::typeof(^), A::Hermitian{<:Complex}, p::Integer) =
612+
_pullback(cx, _hermsympow, A, p)
613+
614+
function _pullback(cx::AContext,
615+
f::typeof(^),
616+
A::LinearAlgebra.RealHermSymComplexHerm,
617+
p::Real)
618+
return _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p)
619+
end
620+
621+
for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt)
622+
@eval begin
623+
function _pullback(cx::AContext,
624+
f::typeof($func),
625+
A::LinearAlgebra.RealHermSymComplexHerm)
626+
return _pullback(cx, A -> _apply_series_func(f, A), A)
627+
end
628+
end
629+
end
630+
631+
@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm)
632+
n = LinearAlgebra.checksquare(A)
633+
λ, U = eigen(A)
634+
sλ, cλ = Buffer(λ), Buffer(λ)
635+
for i in Base.OneTo(n)
636+
@inbounds sλ[i], cλ[i] = sincos(λ[i])
637+
end
638+
sinλ, cosλ = copy(sλ), copy(cλ)
639+
sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U'
640+
Ω, processback = Zygote.pullback(sinA, cosA) do s,c
641+
return (_process_series_matrix(sin, s, A, λ),
642+
_process_series_matrix(cos, c, A, λ))
643+
end
644+
return Ω, function (Ω̄)
645+
s̄inA, c̄osA = processback(Ω̄)
646+
s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U
647+
PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ)
648+
PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ)
649+
= U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U'
650+
return (Ā,)
651+
end
652+
end
653+
492654
Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix)
493655
# x is a squre matrix checked by tr,
494656
# so we could just use Eye(size(x, 1))

src/lib/number.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ end
4949
(s, c), ((s̄, c̄),) -> (s̄*c -*s,)
5050
end
5151

52+
@adjoint acosh(x::Complex) =
53+
acosh(x), Δ ->* conj(inv(sqrt(x - 1) * sqrt(x + 1))),)
54+
5255
@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, -* a // b // b))
5356

5457
@nograd floor, ceil, trunc, round, hash

0 commit comments

Comments
 (0)