|
| 1 | +using ChainRulesCore |
| 2 | +using ChainRulesTestUtils |
| 3 | +using ForwardDiff |
| 4 | +using ImplicitDifferentiation |
| 5 | +using JET |
| 6 | +using LinearAlgebra |
| 7 | +using Random |
| 8 | +using Test |
| 9 | +using Zygote |
| 10 | + |
| 11 | +Random.seed!(63); |
| 12 | + |
| 13 | +function mysqrt(x::AbstractArray) |
| 14 | + a = [0.0] |
| 15 | + a[1] = first(x) |
| 16 | + return sqrt.(x) |
| 17 | +end |
| 18 | + |
| 19 | +forward(x) = mysqrt(x), 0 |
| 20 | +conditions(x, y, z) = y .^ 2 .- x |
| 21 | +implicit = ImplicitFunction(forward, conditions) |
| 22 | + |
| 23 | +# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities |
| 24 | +@testset verbose = true "ChainRulesTestUtils.jl" begin |
| 25 | + @test_skip test_rrule(implicit, x) |
| 26 | + @test_skip test_rrule(implicit, X) |
| 27 | +end |
| 28 | + |
| 29 | +@testset verbose = true "Vectors" begin |
| 30 | + x = rand(2) |
| 31 | + y, _ = implicit(x) |
| 32 | + J = Diagonal(0.5 ./ sqrt.(x)) |
| 33 | + |
| 34 | + @testset "Exactness" begin |
| 35 | + @test (first ∘ implicit)(x) ≈ sqrt.(x) |
| 36 | + @test ForwardDiff.jacobian(first ∘ implicit, x) ≈ J |
| 37 | + @test Zygote.jacobian(first ∘ implicit, x)[1] ≈ J |
| 38 | + end |
| 39 | + |
| 40 | + @testset verbose = true "Forward inference" begin |
| 41 | + x_and_dx = ForwardDiff.Dual.(x, ((0, 0),)) |
| 42 | + @test (@inferred implicit(x_and_dx)) == implicit(x_and_dx) |
| 43 | + y_and_dy, _ = implicit(x_and_dx) |
| 44 | + @test size(y_and_dy) == size(y) |
| 45 | + end |
| 46 | + @testset "Reverse type inference" begin |
| 47 | + _, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, x) |
| 48 | + dy, dz = zero(implicit(x)[1]), 0 |
| 49 | + @test (@inferred pullback((dy, dz))) == pullback((dy, dz)) |
| 50 | + _, dx = pullback((dy, dz)) |
| 51 | + @test size(dx) == size(x) |
| 52 | + end |
| 53 | +end |
| 54 | + |
| 55 | +@testset verbose = true "Arrays" begin |
| 56 | + X = rand(2, 3, 4) |
| 57 | + Y, _ = implicit(X) |
| 58 | + JJ = Diagonal(0.5 ./ sqrt.(vec(X))) |
| 59 | + |
| 60 | + @testset "Exactness" begin |
| 61 | + @test (first ∘ implicit)(X) ≈ sqrt.(X) |
| 62 | + @test ForwardDiff.jacobian(first ∘ implicit, X) ≈ JJ |
| 63 | + @test Zygote.jacobian(first ∘ implicit, X)[1] ≈ JJ |
| 64 | + end |
| 65 | + |
| 66 | + @testset "Forward type inference" begin |
| 67 | + X_and_dX = ForwardDiff.Dual.(X, ((0, 0),)) |
| 68 | + @test (@inferred implicit(X_and_dX)) == implicit(X_and_dX) |
| 69 | + Y_and_dY, _ = implicit(X_and_dX) |
| 70 | + @test size(Y_and_dY) == size(Y) |
| 71 | + end |
| 72 | + |
| 73 | + @testset "Reverse type inference" begin |
| 74 | + _, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, X) |
| 75 | + dY, dZ = zero(implicit(X)[1]), 0 |
| 76 | + @test (@inferred pullback((dY, dZ))) == pullback((dY, dZ)) |
| 77 | + _, dX = pullback((dY, dZ)) |
| 78 | + @test size(dX) == size(X) |
| 79 | + end |
| 80 | +end |
0 commit comments