Skip to content

Commit cf5ab9d

Browse files
authored
Fix flattening bug (#52)
1 parent 8da575b commit cf5ab9d

File tree

6 files changed

+87
-25
lines changed

6 files changed

+87
-25
lines changed

CITATION.bib

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ @misc{ImplicitDifferentiation.jl
22
author = {Guillaume Dalle, Mohamed Tarek and contributors},
33
title = {ImplicitDifferentiation.jl},
44
url = {https://github.com/gdalle/ImplicitDifferentiation.jl},
5-
version = {v0.4.3},
5+
version = {v0.4.4},
66
year = {2023},
77
month = {5}
88
}

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
4-
version = "0.4.3"
4+
version = "0.4.4"
55

66
[deps]
77
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"

examples/0_basic.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -162,25 +162,3 @@ h = rand(2)
162162
J_Z(t) = Zygote.jacobian(first implicit2, x .+ t .* h)[1]
163163
ForwardDiff.derivative(J_Z, 0) Diagonal((-0.25 .* h) ./ (x .^ 1.5))
164164
@test ForwardDiff.derivative(J_Z, 0) Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src
165-
166-
# The following tests are not included in the docs #src
167-
168-
X = rand(2, 3, 4) #src
169-
JJ = Diagonal(0.5 ./ sqrt.(vec(X))) #src
170-
@test (first implicit)(X) sqrt.(X) #src
171-
@test ForwardDiff.jacobian(first implicit, X) JJ #src
172-
@test Zygote.jacobian(first implicit, X)[1] JJ #src
173-
174-
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities #src
175-
@testset verbose = true "ChainRulesTestUtils.jl" begin #src
176-
@test_skip test_rrule(implicit, x) #src
177-
@test_skip test_rrule(implicit, X) #src
178-
end #src
179-
180-
x_and_dx = [ForwardDiff.Dual(x[i], (0, 0)) for i in eachindex(x)] #src
181-
@inferred implicit(x_and_dx) #src
182-
183-
rc = Zygote.ZygoteRuleConfig() #src
184-
_, pullback = @inferred rrule(rc, implicit, x) #src
185-
dy, dz = zero(implicit(x)[1]), 0
186-
@inferred pullback((dy, dz))

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ function (implicit::ImplicitFunction)(
4040
end
4141

4242
y_and_dy = let y = y, dy = dy
43-
map(eachindex(y)) do i
43+
y_and_dy_vec = map(eachindex(y)) do i
4444
Dual{T}(y[i], Partials(ntuple(k -> dy[k][i], Val(N))))
4545
end
46+
reshape(y_and_dy_vec, size(y))
4647
end
4748
return y_and_dy, z
4849
end

test/misc.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using ChainRulesCore
2+
using ChainRulesTestUtils
3+
using ForwardDiff
4+
using ImplicitDifferentiation
5+
using JET
6+
using LinearAlgebra
7+
using Random
8+
using Test
9+
using Zygote
10+
11+
Random.seed!(63);
12+
13+
function mysqrt(x::AbstractArray)
14+
a = [0.0]
15+
a[1] = first(x)
16+
return sqrt.(x)
17+
end
18+
19+
forward(x) = mysqrt(x), 0
20+
conditions(x, y, z) = y .^ 2 .- x
21+
implicit = ImplicitFunction(forward, conditions)
22+
23+
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities
24+
@testset verbose = true "ChainRulesTestUtils.jl" begin
25+
@test_skip test_rrule(implicit, x)
26+
@test_skip test_rrule(implicit, X)
27+
end
28+
29+
@testset verbose = true "Vectors" begin
30+
x = rand(2)
31+
y, _ = implicit(x)
32+
J = Diagonal(0.5 ./ sqrt.(x))
33+
34+
@testset "Exactness" begin
35+
@test (first implicit)(x) sqrt.(x)
36+
@test ForwardDiff.jacobian(first implicit, x) J
37+
@test Zygote.jacobian(first implicit, x)[1] J
38+
end
39+
40+
@testset verbose = true "Forward inference" begin
41+
x_and_dx = ForwardDiff.Dual.(x, ((0, 0),))
42+
@test (@inferred implicit(x_and_dx)) == implicit(x_and_dx)
43+
y_and_dy, _ = implicit(x_and_dx)
44+
@test size(y_and_dy) == size(y)
45+
end
46+
@testset "Reverse type inference" begin
47+
_, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, x)
48+
dy, dz = zero(implicit(x)[1]), 0
49+
@test (@inferred pullback((dy, dz))) == pullback((dy, dz))
50+
_, dx = pullback((dy, dz))
51+
@test size(dx) == size(x)
52+
end
53+
end
54+
55+
@testset verbose = true "Arrays" begin
56+
X = rand(2, 3, 4)
57+
Y, _ = implicit(X)
58+
JJ = Diagonal(0.5 ./ sqrt.(vec(X)))
59+
60+
@testset "Exactness" begin
61+
@test (first implicit)(X) sqrt.(X)
62+
@test ForwardDiff.jacobian(first implicit, X) JJ
63+
@test Zygote.jacobian(first implicit, X)[1] JJ
64+
end
65+
66+
@testset "Forward type inference" begin
67+
X_and_dX = ForwardDiff.Dual.(X, ((0, 0),))
68+
@test (@inferred implicit(X_and_dX)) == implicit(X_and_dX)
69+
Y_and_dY, _ = implicit(X_and_dX)
70+
@test size(Y_and_dY) == size(Y)
71+
end
72+
73+
@testset "Reverse type inference" begin
74+
_, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, X)
75+
dY, dZ = zero(implicit(X)[1]), 0
76+
@test (@inferred pullback((dY, dZ))) == pullback((dY, dZ))
77+
_, dX = pullback((dY, dZ))
78+
@test size(dX) == size(X)
79+
end
80+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples")
5555
@testset verbose = false "Doctests (Documenter.jl)" begin
5656
doctest(ImplicitDifferentiation)
5757
end
58+
@testset verbose = true "Miscellaneous" begin
59+
include("misc.jl")
60+
end
5861
for file in readdir(EXAMPLES_DIR_JL)
5962
path = joinpath(EXAMPLES_DIR_JL, file)
6063
title = markdown_title(path)

0 commit comments

Comments
 (0)