Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cc5b0df
Add adjoints real functions of hermitian matrices
sethaxen Oct 3, 2019
ee1255e
Make exp work for defective matrices. Fixes #340
sethaxen Oct 3, 2019
69eef19
Support arbitrary functions
sethaxen Oct 5, 2019
e47eecb
Unify Symmetric/Hermitian tests
sethaxen Oct 7, 2019
fe6ed06
Simplify hermitian eigen/eigvals tests
sethaxen Oct 7, 2019
54c1c4b
Revert "Add adjoints for real, conj, and imag for arrays"
sethaxen Oct 8, 2019
a9acf7e
Fix complex derivative for acosh
sethaxen Oct 8, 2019
f4e8871
Add sincos
sethaxen Oct 8, 2019
29db72f
Check if array is square
sethaxen Oct 8, 2019
a513195
Revert "Revert "Add adjoints for real, conj, and imag for arrays""
sethaxen Oct 8, 2019
9e9a370
Fix test for Julia 1.0
sethaxen Oct 8, 2019
75afe81
Clean up notation
sethaxen Oct 8, 2019
bfb2d4d
Add adjoint for ^
sethaxen Oct 8, 2019
16d7271
Add type annotation
sethaxen Oct 9, 2019
b1b88f9
Avoid eigendecomposition for integer power
sethaxen Oct 9, 2019
a1b10ec
Clean up notation
sethaxen Oct 9, 2019
5241fba
Rename functions
sethaxen Oct 9, 2019
b13d6d8
Add custom adjoint for _realifydiag!
sethaxen Oct 9, 2019
2841ed0
Don't pullback through post-processing
sethaxen Oct 9, 2019
74225cf
Factor out computation of scalar func pullback
sethaxen Oct 9, 2019
1deb189
Ensure output types are correct
sethaxen Oct 9, 2019
e377fae
Don't split real and imag when output type variable
sethaxen Oct 9, 2019
df8d3e6
Expand and stabilize tests
sethaxen Oct 10, 2019
c6942ef
Combine functions
sethaxen Oct 10, 2019
378955e
Cover all cases in tests
sethaxen Oct 10, 2019
31cab4e
Add custom pullback for exp
sethaxen Oct 10, 2019
53b7f8e
Fix test for p=0
sethaxen Oct 11, 2019
62b5b22
Improve check and accuracy
sethaxen Oct 11, 2019
d8decf2
Merge branch 'master' into sdaxen/hermitian_funcs
sethaxen Nov 4, 2019
fb95ec7
Add adjoints for converting Hermitian/Symmetric to matrix
sethaxen Nov 9, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 158 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,19 +438,39 @@ end
end
end

# Matrix of pairwise difference quotients
Base.@propagate_inbounds function _pairdiffquot(f, i, j, x, fx, dfx, d²fx = nothing)
i == j && return dfx[i]
Δx = x[i] - x[j]
T = real(eltype(x))
if d²fx === nothing
abs(Δx) ≤ sqrt(eps(T)) && return (dfx[i] + dfx[j]) / 2
else
abs(Δx) ≤ eps(T)^(1/3) && return dfx[i] - Δx / 2 * d²fx[i]
end
Δfx = fx[i] - fx[j]
return Δfx / Δx
end

Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = nothing)
Δfij = (i, j)->_pairdiffquot(f, i, j, x, fx, dfx, d²fx)
return Δfij.(Base.OneTo(n), Base.OneTo(n)')
end

# Adjoint based on the Theano implementation, which uses the differential as described
# in Brančík, "Matlab programs for matrix exponential function derivative evaluation"
@adjoint exp(A::AbstractMatrix) = exp(A), function(F̄)
n = size(A, 1)
E = eigen(A)
w = E.values
ew = exp.(w)
X = [i==j ? ew[i] : (ew[i]-ew[j])/(w[i]-w[j]) for i in 1:n,j=1:n]
X = _pairdiffquotmat(exp, n, w, ew, ew, ew)
V = E.vectors
VF = factorize(V)
Ā = (V * ((VF \ F̄' * V) .* X) / VF)'
return (Ā,)
end

@adjoint function LinearAlgebra.eigen(A::LinearAlgebra.RealHermSymComplexHerm)
dU = eigen(A)
return dU, function (Δ)
Expand All @@ -476,6 +496,143 @@ end
return d, d̄ -> (U * Diagonal(d̄) * U',)
end


# Hermitian/Symmetric matrix functions that can be written as power series
_realifydiag!(A::AbstractArray{<:Real}) = A
function _realifydiag!(A)
n = LinearAlgebra.checksquare(A)
for i in 1:n
@inbounds A[i,i] = real(A[i,i])
end
return A
end
@adjoint _realifydiag!(A) = _realifydiag!(A), Δ -> (_realifydiag!(Δ),)

_hasrealdomain(f, x) = true
_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(x -> -1 ≤ x ≤ 1, x)
_hasrealdomain(::typeof(acosh), x) = all(x -> x ≥ 1, x)
_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(x -> x ≥ 0, x)

_process_series_eigvals(f, λ) = _hasrealdomain(f, λ) ? λ : complex.(λ)

_process_series_matrix(f, fA, A, fλ) = fA
_process_series_matrix(f, fA, ::LinearAlgebra.HermOrSym{<:Real}, fλ) = Symmetric(fA)
_process_series_matrix(f, fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Real}) =
Hermitian(_realifydiag!(fA))
_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, fλ) = Hermitian(fA)
_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, ::AbstractVector{<:Complex}) = fA
_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Complex}) = fA

# Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives,
# and function to pull back adjoints to args
function _pullback_series_func_scalar(f, λ, args...)
compλ = _process_series_eigvals(f, λ)
fλ, fback = Zygote.pullback((x,args...) -> f.(x, args...), compλ, args...)
n = length(λ)
return (fλ,
()->fback(ones(n))[1],
()->nothing, # TODO: add 2nd deriv
isempty(args) ? _ -> () : f̄λ -> tail(fback(f̄λ)))
end

function _pullback_series_func_scalar(f::typeof(^), λ, p)
compλ = _process_series_eigvals(f, λ)
r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ)
fλ = powλ .^ r
return (fλ,
()->conj.(r .* powλ .^ (r - 1)),
()->conj.((r * (r - 1)) .* powλ .^ (r - 2)),
f̄λ -> (dot(fλ .* log.(compλ), f̄λ),))
end

function _pullback_series_func_scalar(f::typeof(exp), λ)
expλ = exp.(λ)
return expλ, ()->expλ, ()->expλ, _ -> ()
end

_apply_series_func(f, A, args...) = f(A, args...)

@adjoint function _apply_series_func(f, A, args...)
hasargs = !isempty(args)
n = LinearAlgebra.checksquare(A)
λ, U = eigen(A)
fλ, dfthunk, d²fthunk, argsback = _pullback_series_func_scalar(f, λ, args...)
fΛ = Diagonal(fλ)
fA = U * fΛ * U'
Ω = _process_series_matrix(f, fA, A, fλ)
return Ω, function (f̄A)
f̄Λ = U' * f̄A * U
ārgs = hasargs ? argsback(diag(f̄Λ)) : ()
P = _pairdiffquotmat(f, n, λ, conj(fλ), dfthunk(), d²fthunk())
Ā = U * (P .* f̄Λ) * U'
return (nothing, Ā, ārgs...)
end
end

_hermsympow(A::Symmetric, p::Integer) = LinearAlgebra.sympow(A, p)
_hermsympow(A::Hermitian, p::Integer) = A^p

@adjoint function _hermsympow(A::Hermitian, p::Integer)
if p < 0
B, back = Zygote.pullback(A->Base.power_by_squaring(inv(A), -p), A)
else
B, back = Zygote.pullback(A->Base.power_by_squaring(A, p), A)
end
Ω = Hermitian(_realifydiag!(B))
return Ω, function (Ω̄)
B̄ = _hermitian_back(Ω̄, 'U')
Ā = back(B̄)[1]
return (Ā, nothing)
end
end

_pullback(cx::AContext, ::typeof(^), A::LinearAlgebra.HermOrSym{<:Real}, p::Integer) =
_pullback(cx, _hermsympow, A, p)
_pullback(cx::AContext, ::typeof(^), A::Symmetric{<:Complex}, p::Integer) =
_pullback(cx, _hermsympow, A, p)
_pullback(cx::AContext, ::typeof(^), A::Hermitian{<:Complex}, p::Integer) =
_pullback(cx, _hermsympow, A, p)

function _pullback(cx::AContext,
f::typeof(^),
A::LinearAlgebra.RealHermSymComplexHerm,
p::Real)
return _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p)
end

for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt)
@eval begin
function _pullback(cx::AContext,
f::typeof($func),
A::LinearAlgebra.RealHermSymComplexHerm)
return _pullback(cx, A -> _apply_series_func(f, A), A)
end
end
end

@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm)
n = LinearAlgebra.checksquare(A)
λ, U = eigen(A)
sλ, cλ = Buffer(λ), Buffer(λ)
for i in Base.OneTo(n)
@inbounds sλ[i], cλ[i] = sincos(λ[i])
end
sinλ, cosλ = copy(sλ), copy(cλ)
sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U'
Ω, processback = Zygote.pullback(sinA, cosA) do s,c
return (_process_series_matrix(sin, s, A, λ),
_process_series_matrix(cos, c, A, λ))
end
return Ω, function (Ω̄)
s̄inA, c̄osA = processback(Ω̄)
s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U
PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ)
PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ)
Ā = U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U'
return (Ā,)
end
end

Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix)
# x is a squre matrix checked by tr,
# so we could just use Eye(size(x, 1))
Expand Down
3 changes: 3 additions & 0 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ end
(s, c), ((s̄, c̄),) -> (s̄*c - c̄*s,)
end

@adjoint acosh(x::Complex) =
acosh(x), Δ -> (Δ * conj(inv(sqrt(x - 1) * sqrt(x + 1))),)

@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b))

@nograd floor, ceil, trunc, round, hash
Expand Down
Loading