|
1 | | -# Figure out which AD backend to test |
2 | | -const AD = get(ENV, "AD", "All") |
| 1 | +using DifferentiationInterface |
3 | 2 |
|
4 | | -function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) |
5 | | - for b in broken |
6 | | - if !( |
7 | | - b in ( |
8 | | - :ForwardDiff, |
9 | | - :Mooncake, |
10 | | - :ReverseDiff, |
11 | | - :Enzyme, |
12 | | - :EnzymeForward, |
13 | | - :EnzymeReverse, |
14 | | - # The `Crash` ones indicate that the error will cause a Julia crash, and |
15 | | - # thus we can't even run `@test_broken on it. |
16 | | - :EnzymeForwardCrash, |
17 | | - :EnzymeReverseCrash, |
18 | | - ) |
19 | | - ) |
20 | | - error("Unknown broken AD backend: $b") |
21 | | - end |
22 | | - end |
| 3 | +const REF_BACKEND = AutoFiniteDifferences(; fdm=central_fdm(5, 1)) |
23 | 4 |
|
24 | | - finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] |
25 | | - |
26 | | - if AD == "All" || AD == "ForwardDiff" |
27 | | - if :ForwardDiff in broken |
28 | | - @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol |
29 | | - else |
30 | | - @test ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol |
31 | | - end |
32 | | - end |
33 | | - |
34 | | - if AD == "All" || AD == "ReverseDiff" |
35 | | - if :ReverseDiff in broken |
36 | | - @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol |
37 | | - else |
38 | | - @test ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol |
39 | | - end |
40 | | - end |
41 | | - |
42 | | - if AD == "All" || AD == "Enzyme" |
43 | | - forward_broken = :EnzymeForward in broken || :Enzyme in broken |
44 | | - reverse_broken = :EnzymeReverse in broken || :Enzyme in broken |
45 | | - if !(:EnzymeForwardCrash in broken) |
46 | | - if forward_broken |
47 | | - @test_broken( |
48 | | - Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] ≈ finitediff, |
49 | | - rtol = rtol, |
50 | | - atol = atol |
51 | | - ) |
52 | | - else |
53 | | - @test( |
54 | | - Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] ≈ finitediff, |
55 | | - rtol = rtol, |
56 | | - atol = atol |
57 | | - ) |
58 | | - end |
59 | | - end |
60 | | - |
61 | | - if !(:EnzymeReverseCrash in broken) |
62 | | - if reverse_broken |
63 | | - @test_broken( |
64 | | - Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1] ≈ |
65 | | - finitediff, |
66 | | - rtol = rtol, |
67 | | - atol = atol |
68 | | - ) |
69 | | - else |
70 | | - @test( |
71 | | - Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1] ≈ |
72 | | - finitediff, |
73 | | - rtol = rtol, |
74 | | - atol = atol |
75 | | - ) |
76 | | - end |
77 | | - end |
78 | | - end |
79 | | - |
80 | | - if AD == "All" || AD == "Mooncake" |
81 | | - rule = Mooncake.build_rrule(f, x) |
82 | | - if :Mooncake in broken |
83 | | - @test_broken isapprox( |
84 | | - Mooncake.value_and_gradient!!(rule, f, x)[2][2], |
85 | | - finitediff; |
86 | | - rtol=rtol, |
87 | | - atol=atol, |
88 | | - ) |
89 | | - else |
90 | | - @test isapprox( |
91 | | - Mooncake.value_and_gradient!!(rule, f, x)[2][2], |
92 | | - finitediff; |
93 | | - rtol=rtol, |
94 | | - atol=atol, |
95 | | - ) |
96 | | - end |
97 | | - end |
98 | | - |
99 | | - return nothing |
| 5 | +function test_ad(f, backend, x; rtol=1e-6, atol=1e-6) |
| 6 | + ref_gradient = DifferentiationInterface.gradient(f, REF_BACKEND, x) |
| 7 | + gradient = DifferentiationInterface.gradient(f, backend, x) |
| 8 | + @test isapprox(gradient, ref_gradient; rtol=rtol, atol=atol) |
100 | 9 | end |
0 commit comments