Skip to content

Commit 028d3fb

Browse files
authored
More lenient iterative linear solver (#119)
* More lenient iterative linear solver * Typo * Fix bools * Fix tests and add docs
1 parent f58c705 commit 028d3fb

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"

src/linear_solver.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,36 @@ abstract type AbstractLinearSolver end
1616
"""
1717
IterativeLinearSolver
1818
19-
An implementation of `AbstractLinearSolver` using `Krylov.gmres`.
19+
An implementation of `AbstractLinearSolver` using `Krylov.gmres`, set as the default for `ImplicitFunction`.
2020
2121
# Fields
2222
23-
- `verbose::Bool`: Whether to throw a warning when the solver fails (defaults to `true`)
23+
- `verbose::Bool`: Whether to display a warning when the solver fails and returns `NaN`s (defaults to `true`)
24+
- `accept_inconsistent::Bool`: Whether to accept approximate least-squares solutions for inconsistent systems, or fail and return `NaN`s (defaults to `false`)
25+
26+
!!! note
27+
If you find that your implicit gradients contain `NaN`s, try using this solver with `accept_inconsistent=true`.
28+
However, beware that the implicit function theorem does not cover the case of inconsistent linear systems `AJ = B`, so it is unclear what the result will mean.
2429
"""
2530
Base.@kwdef struct IterativeLinearSolver <: AbstractLinearSolver
2631
verbose::Bool = true
32+
accept_inconsistent::Bool = false
2733
end
2834

2935
presolve(::IterativeLinearSolver, A, y) = A
3036

3137
function solve(sol::IterativeLinearSolver, A, b)
3238
x, stats = gmres(A, b)
33-
if !stats.solved || stats.inconsistent
34-
sol.verbose && @warn "IterativeLinearSolver failed, result contains NaNs"
39+
if sol.accept_inconsistent
40+
success = stats.solved || stats.inconsistent
41+
else
42+
success = stats.solved && !stats.inconsistent
43+
end
44+
if !success
45+
if sol.verbose
46+
@warn "IterativeLinearSolver failed, result contains NaNs"
47+
@show stats
48+
end
3549
x .= NaN
3650
end
3751
return x

test/errors.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ end
2626
@testset verbose = true "Derivative NaNs" begin
2727
x = zeros(Float32, 2)
2828
linear_solvers = (
29-
IterativeLinearSolver(; verbose=false), DirectLinearSolver(; verbose=false)
29+
IterativeLinearSolver(; verbose=false), #
30+
IterativeLinearSolver(; verbose=false, accept_inconsistent=true), #
31+
DirectLinearSolver(; verbose=false), #
3032
)
33+
function should_give_nan(linear_solver)
34+
return linear_solver isa DirectLinearSolver || !linear_solver.accept_inconsistent
35+
end
3136

3237
@testset "Infinite derivative" begin
3338
f = x -> sqrt.(x) # nondifferentiable at 0
@@ -37,8 +42,10 @@ end
3742
implicit = ImplicitFunction(f, c; linear_solver)
3843
J1 = ForwardDiff.jacobian(implicit, x)
3944
J2 = Zygote.jacobian(implicit, x)[1]
40-
@test all(isnan, J1) && eltype(J1) == Float32
41-
@test all(isnan, J2) && eltype(J2) == Float32
45+
@test all(isnan, J1) == should_give_nan(linear_solver)
46+
@test all(isnan, J2) == should_give_nan(linear_solver)
47+
@test eltype(J1) == Float32
48+
@test eltype(J2) == Float32
4249
end
4350
end
4451
end
@@ -51,8 +58,10 @@ end
5158
implicit = ImplicitFunction(f, c; linear_solver)
5259
J1 = ForwardDiff.jacobian(implicit, x)
5360
J2 = Zygote.jacobian(implicit, x)[1]
54-
@test all(isnan, J1) && eltype(J1) == Float32
55-
@test all(isnan, J2) && eltype(J2) == Float32
61+
@test all(isnan, J1) == should_give_nan(linear_solver)
62+
@test all(isnan, J2) == should_give_nan(linear_solver)
63+
@test eltype(J1) == Float32
64+
@test eltype(J2) == Float32
5665
end
5766
end
5867
end

0 commit comments

Comments
 (0)