Skip to content

Commit 78f3ef9

Browse files
committed
Refactor AD tests
1 parent cd4e62f commit 78f3ef9

File tree

8 files changed

+54
-192
lines changed

8 files changed

+54
-192
lines changed

.github/workflows/AD.yml

Lines changed: 0 additions & 45 deletions
This file was deleted.
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Interface tests
1+
name: CI
22

33
on:
44
push:
@@ -23,12 +23,17 @@ jobs:
2323
- '1'
2424
os:
2525
- ubuntu-latest
26+
group:
27+
- 'Interface'
28+
- 'AD'
29+
2630
steps:
2731
- uses: actions/checkout@v4
2832
- uses: julia-actions/setup-julia@v2
2933
with:
3034
version: ${{ matrix.version }}
35+
- uses: julia-actions/cache@v2
3136
- uses: julia-actions/julia-buildpkg@v1
3237
- uses: julia-actions/julia-runtest@v1
3338
env:
34-
GROUP: Interface
39+
GROUP: ${{ matrix.group }}

test/ad/corr.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "AD for VecCorrBijector" begin
1+
@testset "VecCorrBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
22
@testset "d = $d" for d in (1, 2, 4)
33
dist = LKJ(d, 2.0)
44
b = bijector(dist)
@@ -8,13 +8,13 @@
88
y = b(x)
99

1010
# roundtrip
11-
test_ad(y -> sum(transform(b, binv(y))), y)
11+
test_ad(y -> sum(transform(b, binv(y))), adtype, y)
1212
# inverse only
13-
test_ad(y -> sum(transform(binv, y)), y)
13+
test_ad(y -> sum(transform(binv, y)), adtype, y)
1414
end
1515
end
1616

17-
@testset "AD for VecCholeskyBijector" begin
17+
@testset "VecCholeskyBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
1818
@testset "d = $d, uplo = $uplo" for d in (1, 2, 4), uplo in ('U', 'L')
1919
dist = LKJCholesky(d, 2.0, uplo)
2020
b = bijector(dist)
@@ -24,14 +24,14 @@ end
2424
y = b(x)
2525

2626
# roundtrip
27-
test_ad(y -> sum(transform(b, binv(y))), y)
27+
test_ad(y -> sum(transform(b, binv(y))), adtype, y)
2828
# inverse (we need to tack on `cholesky_upper`/`cholesky_lower`,
2929
# because directly calling `sum` on a LinearAlgebra.Cholesky doesn't
3030
# give a scalar)
3131
if uplo == 'U'
32-
test_ad(y -> sum(Bijectors.cholesky_upper(transform(binv, y))), y)
32+
test_ad(y -> sum(Bijectors.cholesky_upper(transform(binv, y))), adtype, y)
3333
else
34-
test_ad(y -> sum(Bijectors.cholesky_lower(transform(binv, y))), y)
34+
test_ad(y -> sum(Bijectors.cholesky_lower(transform(binv, y))), adtype, y)
3535
end
3636
end
3737
end

test/ad/flows.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
1-
@testset "PlanarLayer" begin
1+
@testset "PlanarLayer: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
22
# logpdf of a flow with a planar layer and two-dimensional inputs
3-
test_ad(randn(7)) do θ
3+
function f(θ)
44
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
55
flow = transformed(MvNormal(zeros(2), I), layer)
66
x = θ[6:7]
77
return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
88
end
9-
test_ad(randn(11)) do θ
9+
test_ad(f, adtype, randn(7))
10+
11+
function g(θ)
1012
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
1113
flow = transformed(MvNormal(zeros(2), I), layer)
1214
x = reshape(θ[6:end], 2, :)
1315
return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x))
1416
end
17+
test_ad(g, adtype, randn(11))
1518

1619
# logpdf of a flow with the inverse of a planar layer and two-dimensional inputs
17-
test_ad(randn(7)) do θ
20+
function finv(θ)
1821
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
1922
flow = transformed(MvNormal(zeros(2), I), inverse(layer))
2023
x = θ[6:7]
2124
return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
2225
end
23-
test_ad(randn(11)) do θ
26+
test_ad(finv, adtype, randn(7))
27+
28+
function ginv(θ)
2429
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
2530
flow = transformed(MvNormal(zeros(2), I), inverse(layer))
2631
x = reshape(θ[6:end], 2, :)
2732
return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x))
2833
end
34+
test_ad(ginv, adtype, randn(11))
2935
end

test/ad/pd.jl

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
_topd(x) = x * x' + I
22

3-
@testset "AD for PDVecBijector" begin
3+
@testset "PDVecBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
44
d = 4
55
b = Bijectors.PDVecBijector()
66
binv = inverse(b)
@@ -9,21 +9,11 @@ _topd(x) = x * x' + I
99
x = _topd(z)
1010
y = b(x)
1111

12-
test_ad(vec(z)) do x
13-
sum(transform(b, _topd(reshape(x, d, d))))
14-
end
12+
test_ad(x -> sum(transform(b, _topd(reshape(x, d, d)))), adtype, vec(z))
13+
test_ad(y -> sum(transform(binv, y)), adtype, y)
1514

16-
test_ad(y) do y
17-
sum(transform(binv, y))
18-
end
19-
20-
if AD == "ReverseDiff" # `AD` is defined in `test/ad/utils.jl`.
21-
test_ad(y) do y
22-
sum(Bijectors.cholesky_lower(transform(binv, y)))
23-
end
24-
25-
test_ad(y) do y
26-
sum(Bijectors.cholesky_upper(transform(binv, y)))
27-
end
28-
end
15+
# if occursin("ReverseDiff", backend_name)
16+
test_ad(y -> sum(Bijectors.cholesky_lower(transform(binv, y))), adtype, y)
17+
test_ad(y -> sum(Bijectors.cholesky_upper(transform(binv, y))), adtype, y)
18+
# end
2919
end

test/ad/stacked.jl

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "AD for StackedBijector" begin
1+
@testset "StackedBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
22
dist1 = Dirichlet(4, 1.0)
33
b1 = bijector(dist1)
44

@@ -17,22 +17,12 @@
1717
y = vcat(y1, [y2])
1818
x = binv(y)
1919

20-
test_ad(y) do x
21-
sum(transform(b, binv(x)))
22-
end
23-
24-
test_ad(y) do y
25-
sum(transform(binv, y))
26-
end
20+
test_ad(y -> sum(transform(b, binv(y))), adtype, y)
21+
test_ad(y -> sum(transform(binv, y)), adtype, y)
2722

2823
bvec = Stacked([b1, b2], [1:4, 5:5])
2924
bvec_inv = inverse(bvec)
3025

31-
test_ad(y) do x
32-
sum(transform(bvec, binv(x)))
33-
end
34-
35-
test_ad(y) do y
36-
sum(transform(bvec_inv, y))
37-
end
26+
test_ad(y -> sum(transform(bvec, binv(y))), adtype, y)
27+
test_ad(y -> sum(transform(bvec_inv, y)), adtype, y)
3828
end

test/ad/utils.jl

Lines changed: 6 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,9 @@
1-
# Figure out which AD backend to test
2-
const AD = get(ENV, "AD", "All")
1+
using DifferentiationInterface
32

4-
function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
5-
for b in broken
6-
if !(
7-
b in (
8-
:ForwardDiff,
9-
:Mooncake,
10-
:ReverseDiff,
11-
:Enzyme,
12-
:EnzymeForward,
13-
:EnzymeReverse,
14-
# The `Crash` ones indicate that the error will cause a Julia crash, and
15-
# thus we can't even run `@test_broken on it.
16-
:EnzymeForwardCrash,
17-
:EnzymeReverseCrash,
18-
)
19-
)
20-
error("Unknown broken AD backend: $b")
21-
end
22-
end
3+
const REF_BACKEND = AutoFiniteDifferences(; fdm=central_fdm(5, 1))
234

24-
finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
25-
26-
if AD == "All" || AD == "ForwardDiff"
27-
if :ForwardDiff in broken
28-
@test_broken ForwardDiff.gradient(f, x) finitediff rtol = rtol atol = atol
29-
else
30-
@test ForwardDiff.gradient(f, x) finitediff rtol = rtol atol = atol
31-
end
32-
end
33-
34-
if AD == "All" || AD == "ReverseDiff"
35-
if :ReverseDiff in broken
36-
@test_broken ReverseDiff.gradient(f, x) finitediff rtol = rtol atol = atol
37-
else
38-
@test ReverseDiff.gradient(f, x) finitediff rtol = rtol atol = atol
39-
end
40-
end
41-
42-
if AD == "All" || AD == "Enzyme"
43-
forward_broken = :EnzymeForward in broken || :Enzyme in broken
44-
reverse_broken = :EnzymeReverse in broken || :Enzyme in broken
45-
if !(:EnzymeForwardCrash in broken)
46-
if forward_broken
47-
@test_broken(
48-
Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] finitediff,
49-
rtol = rtol,
50-
atol = atol
51-
)
52-
else
53-
@test(
54-
Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] finitediff,
55-
rtol = rtol,
56-
atol = atol
57-
)
58-
end
59-
end
60-
61-
if !(:EnzymeReverseCrash in broken)
62-
if reverse_broken
63-
@test_broken(
64-
Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1]
65-
finitediff,
66-
rtol = rtol,
67-
atol = atol
68-
)
69-
else
70-
@test(
71-
Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1]
72-
finitediff,
73-
rtol = rtol,
74-
atol = atol
75-
)
76-
end
77-
end
78-
end
79-
80-
if AD == "All" || AD == "Mooncake"
81-
rule = Mooncake.build_rrule(f, x)
82-
if :Mooncake in broken
83-
@test_broken isapprox(
84-
Mooncake.value_and_gradient!!(rule, f, x)[2][2],
85-
finitediff;
86-
rtol=rtol,
87-
atol=atol,
88-
)
89-
else
90-
@test isapprox(
91-
Mooncake.value_and_gradient!!(rule, f, x)[2][2],
92-
finitediff;
93-
rtol=rtol,
94-
atol=atol,
95-
)
96-
end
97-
end
98-
99-
return nothing
5+
function test_ad(f, backend, x; rtol=1e-6, atol=1e-6)
6+
ref_gradient = DifferentiationInterface.gradient(f, REF_BACKEND, x)
7+
gradient = DifferentiationInterface.gradient(f, backend, x)
8+
@test isapprox(gradient, ref_gradient; rtol=rtol, atol=atol)
1009
end

test/runtests.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using InverseFunctions: InverseFunctions
3333
using LazyArrays: LazyArrays
3434

3535
const GROUP = get(ENV, "GROUP", "All")
36+
const IS_PRERELEASE = !isempty(VERSION.prerelease)
3637

3738
# Always include this since it can be useful for other tests.
3839
include("ad/utils.jl")
@@ -52,7 +53,6 @@ if GROUP == "All" || GROUP == "Interface"
5253
include("bijectors/reshape.jl")
5354
include("bijectors/corr.jl")
5455
include("bijectors/product_bijector.jl")
55-
5656
include("distributionsad.jl")
5757

5858
@testset "doctests" begin
@@ -68,13 +68,20 @@ if GROUP == "All" || GROUP == "Interface"
6868
end
6969
end
7070

71+
TEST_ADTYPES = [
72+
("ForwardDiff", AutoForwardDiff()),
73+
("ReverseDiff", AutoReverseDiff(; compile=false)),
74+
("ReverseDiffCompiled", AutoReverseDiff(; compile=true)),
75+
]
76+
if !IS_PRERELEASE
77+
push!(TEST_ADTYPES, ("Mooncake", AutoMooncake()))
78+
end
79+
7180
if GROUP == "All" || GROUP == "AD"
7281
include("ad/chainrules.jl")
73-
if get(ENV, "AD", "All") in ("All", "Enzyme")
74-
include("ad/enzyme.jl")
75-
end
7682
include("ad/flows.jl")
7783
include("ad/pd.jl")
7884
include("ad/corr.jl")
7985
include("ad/stacked.jl")
86+
include("ad/enzyme.jl")
8087
end

0 commit comments

Comments
 (0)