Skip to content

Commit d45897a

Browse files
authored
Complete revamp (#160)
* Complete revamp * Fix * Revamp complete * ExplicitImports * test preparation
1 parent f45ada3 commit d45897a

27 files changed

+1330
-890
lines changed

Project.toml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
3-
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
4-
version = "0.6.3"
3+
authors = ["Guillaume Dalle", "Mohamed Tarek"]
4+
version = "0.7.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -12,23 +12,23 @@ LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
1212

1313
[weakdeps]
1414
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
15-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1615
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
16+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[extensions]
1919
ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore"
20-
ImplicitDifferentiationEnzymeExt = "Enzyme"
2120
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
21+
ImplicitDifferentiationZygoteExt = "Zygote"
2222

2323
[compat]
2424
ADTypes = "1.9.0"
2525
ChainRulesCore = "1.25.0"
2626
DifferentiationInterface = "0.6.1"
27-
Enzyme = "0.13.3"
2827
ForwardDiff = "0.10.36"
2928
Krylov = "0.9.6"
3029
LinearAlgebra = "1.10"
3130
LinearOperators = "2.8.0"
31+
Zygote = "0.7.4"
3232
julia = "1.10"
3333

3434
[extras]
@@ -39,7 +39,8 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3939
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
4040
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
4141
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
42-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
42+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
43+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
4344
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4445
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4546
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
@@ -53,4 +54,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5354
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5455

5556
[targets]
56-
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "Enzyme", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]
57+
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]

build/ImplicitDifferentiation.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
ImplicitDifferentiation
3+
4+
A Julia package for automatic differentiation of implicit functions.
5+
6+
Its main export is the type [`ImplicitFunction`](@ref).
7+
"""
8+
module ImplicitDifferentiation
9+
10+
using ADTypes: AbstractADType
11+
using DifferentiationInterface:
12+
Constant,
13+
jacobian,
14+
prepare_jacobian,
15+
prepare_pullback,
16+
prepare_pullback_same_point,
17+
prepare_pushforward,
18+
prepare_pushforward_same_point,
19+
pullback!,
20+
pushforward!
21+
using Krylov: gmres
22+
using LinearOperators: LinearOperator
23+
using LinearAlgebra: factorize
24+
25+
include("settings.jl")
26+
include("preparation.jl")
27+
include("implicit_function.jl")
28+
include("execution.jl")
29+
30+
export KrylovLinearSolver
31+
export MatrixRepresentation, OperatorRepresentation
32+
export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation
33+
export ImplicitFunction
34+
35+
end

build/execution.jl

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
const SYMMETRIC = false
2+
const HERMITIAN = false
3+
4+
struct JVP!{F,P,B,X,C}
5+
f::F
6+
prep::P
7+
backend::B
8+
x::X
9+
contexts::C
10+
end
11+
12+
struct VJP!{F,P,B,X,C}
13+
f::F
14+
prep::P
15+
backend::B
16+
x::X
17+
contexts::C
18+
end
19+
20+
function (po::JVP!)(res::AbstractVector, v::AbstractVector)
21+
(; f, backend, x, contexts, prep) = po
22+
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
23+
return res
24+
end
25+
26+
function (po::VJP!)(res::AbstractVector, v::AbstractVector)
27+
(; f, backend, x, contexts, prep) = po
28+
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
29+
return res
30+
end
31+
32+
## A
33+
34+
function build_A(
35+
implicit::ImplicitFunction,
36+
x::AbstractVector,
37+
y::AbstractVector,
38+
z,
39+
args...;
40+
suggested_backend::AbstractADType,
41+
)
42+
return build_A_aux(
43+
implicit.representation, implicit, x, y, z, args...; suggested_backend
44+
)
45+
end
46+
47+
function build_A_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend)
48+
(; conditions, backend, prep_A) = implicit
49+
actual_backend = isnothing(backend) ? suggested_backend : backend
50+
contexts = (Constant(x), Constant(z), map(Constant, args)...)
51+
A = jacobian(Switch12(conditions), prep_A..., actual_backend, y, contexts...)
52+
return factorize(A)
53+
end
54+
55+
function build_A_aux(
56+
::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend
57+
)
58+
(; conditions, backend, prep_A) = implicit
59+
actual_backend = isnothing(backend) ? suggested_backend : backend
60+
contexts = (Constant(x), Constant(z), map(Constant, args)...)
61+
prep_A_same = prepare_pushforward_same_point(
62+
Switch12(conditions), prep_A..., actual_backend, y, (zero(y),), contexts...
63+
)
64+
prod! = JVP!(Switch12(conditions), prep_A_same, actual_backend, y, contexts)
65+
return LinearOperator(
66+
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y)
67+
)
68+
end
69+
70+
## Aᵀ
71+
72+
function build_Aᵀ(
73+
implicit::ImplicitFunction,
74+
x::AbstractVector,
75+
y::AbstractVector,
76+
z,
77+
args...;
78+
suggested_backend::AbstractADType,
79+
)
80+
return build_Aᵀ_aux(
81+
implicit.representation, implicit, x, y, z, args...; suggested_backend
82+
)
83+
end
84+
85+
function build_Aᵀ_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend)
86+
(; conditions, backend, prep_Aᵀ) = implicit
87+
actual_backend = isnothing(backend) ? suggested_backend : backend
88+
contexts = (Constant(x), Constant(z), map(Constant, args)...)
89+
Aᵀ = transpose(
90+
jacobian(Switch12(conditions), prep_Aᵀ..., actual_backend, y, contexts...)
91+
)
92+
return factorize(Aᵀ)
93+
end
94+
95+
function build_Aᵀ_aux(
96+
::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend
97+
)
98+
(; conditions, backend, prep_Aᵀ) = implicit
99+
actual_backend = isnothing(backend) ? suggested_backend : backend
100+
contexts = (Constant(x), Constant(z), map(Constant, args)...)
101+
prep_Aᵀ_same = prepare_pullback_same_point(
102+
Switch12(conditions), prep_Aᵀ..., actual_backend, y, (zero(y),), contexts...
103+
)
104+
prod! = VJP!(Switch12(conditions), prep_Aᵀ_same, actual_backend, y, contexts)
105+
return LinearOperator(
106+
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y)
107+
)
108+
end
109+
110+
## B
111+
112+
function build_B(
113+
implicit::ImplicitFunction,
114+
x::AbstractVector,
115+
y::AbstractVector,
116+
z,
117+
args...;
118+
suggested_backend::AbstractADType,
119+
)
120+
return build_B_aux(
121+
implicit.representation, implicit, x, y, z, args...; suggested_backend
122+
)
123+
end
124+
125+
function build_B_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend)
126+
(; conditions, backend, prep_B) = implicit
127+
actual_backend = isnothing(backend) ? suggested_backend : backend
128+
contexts = (Constant(y), Constant(z), map(Constant, args)...)
129+
return jacobian(conditions, prep_B..., actual_backend, x, contexts...)
130+
end
131+
132+
function build_B_aux(
133+
::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend
134+
)
135+
(; conditions, backend, prep_B) = implicit
136+
actual_backend = isnothing(backend) ? suggested_backend : backend
137+
contexts = (Constant(y), Constant(z), map(Constant, args)...)
138+
prep_B_same = prepare_pushforward_same_point(
139+
conditions, prep_B..., actual_backend, x, (zero(x),), contexts...
140+
)
141+
prod! = JVP!(conditions, prep_B_same, actual_backend, x, contexts)
142+
return LinearOperator(
143+
eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x)
144+
)
145+
end
146+
147+
## Bᵀ
148+
149+
function build_Bᵀ(
150+
implicit::ImplicitFunction,
151+
x::AbstractVector,
152+
y::AbstractVector,
153+
z,
154+
args...;
155+
suggested_backend::AbstractADType,
156+
)
157+
return build_Bᵀ_aux(
158+
implicit.representation, implicit, x, y, z, args...; suggested_backend
159+
)
160+
end
161+
162+
function build_Bᵀ_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend)
163+
(; conditions, backend, prep_Bᵀ) = implicit
164+
actual_backend = isnothing(backend) ? suggested_backend : backend
165+
contexts = (Constant(y), Constant(z), map(Constant, args)...)
166+
return transpose(jacobian(conditions, prep_Bᵀ..., actual_backend, x, contexts...))
167+
end
168+
169+
function build_Bᵀ_aux(
170+
::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend
171+
)
172+
(; conditions, backend, prep_Bᵀ) = implicit
173+
actual_backend = isnothing(backend) ? suggested_backend : backend
174+
contexts = (Constant(y), Constant(z), map(Constant, args)...)
175+
prep_Bᵀ_same = prepare_pullback_same_point(
176+
conditions, prep_Bᵀ..., actual_backend, x, (zero(y),), contexts...
177+
)
178+
prod! = VJP!(conditions, prep_Bᵀ_same, actual_backend, x, contexts)
179+
return LinearOperator(
180+
eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x)
181+
)
182+
end

0 commit comments

Comments
 (0)