Skip to content

Commit 117b8e8

Browse files
committed
Remove test_ad calls from interface tests
1 parent c520102 commit 117b8e8

File tree

3 files changed

+26
-31
lines changed

3 files changed

+26
-31
lines changed

test/ad/corr.jl

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,37 @@
11
@testset "AD for VecCorrBijector" begin
2-
d = 4
3-
dist = LKJ(d, 2.0)
4-
b = bijector(dist)
5-
binv = inverse(b)
2+
@testset "d = $d" for d in (1, 2, 4)
3+
dist = LKJ(d, 2.0)
4+
b = bijector(dist)
5+
binv = inverse(b)
66

7-
x = rand(dist)
8-
y = b(x)
7+
x = rand(dist)
8+
y = b(x)
99

10-
test_ad(y) do x
11-
sum(transform(b, binv(x)))
12-
end
13-
14-
test_ad(y) do y
15-
sum(transform(binv, y))
10+
# roundtrip
11+
test_ad(y -> sum(transform(b, binv(y))), y)
12+
# inverse only
13+
test_ad(y -> sum(transform(binv, y)), y)
1614
end
1715
end
1816

1917
@testset "AD for VecCholeskyBijector" begin
20-
d = 4
21-
dist = LKJCholesky(d, 2.0)
22-
b = bijector(dist)
23-
binv = inverse(b)
18+
@testset "d = $d, uplo = $uplo" for d in (1, 2, 4), uplo in ('U', 'L')
19+
dist = LKJCholesky(d, 2.0, uplo)
20+
b = bijector(dist)
21+
binv = inverse(b)
2422

25-
x = rand(dist)
26-
y = b(x)
27-
28-
test_ad(y) do y
29-
sum(transform(b, binv(y)))
30-
end
23+
x = rand(dist)
24+
y = b(x)
3125

32-
test_ad(y) do y
33-
sum(Bijectors.cholesky_upper(transform(binv, y)))
26+
# roundtrip
27+
test_ad(y -> sum(transform(b, binv(y))), y)
28+
# inverse only
29+
test_ad(y -> sum(transform(binv, y)), y)
30+
# additionally check that cholesky_{upper,lower} is differentiable
31+
if uplo == 'U'
32+
test_ad(y -> sum(Bijectors.cholesky_upper(transform(b, y))), y)
33+
else
34+
test_ad(y -> sum(Bijectors.cholesky_lower(transform(b, y))), y)
35+
end
3436
end
3537
end

test/ad/utils.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# Figure out which AD backend to test
2-
const AD = get(ENV, "AD", "All")
3-
41
function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
52
for b in broken
63
if !(

test/bijectors/corr.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
3232
test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)
3333
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)
3434

35-
test_ad(x -> sum(bvec(bvecinv(x))), yvec)
36-
3735
# Check that output sizes are computed correctly.
3836
tdist = transformed(dist)
3937
@test length(tdist) == length(yvec)
@@ -64,8 +62,6 @@ end
6462

6563
@test xinv.U cholesky(xinv_lkj).U
6664

67-
test_ad(x -> sum(b(binv(x))), y)
68-
6965
# test_bijector is commented out for now,
7066
# as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky)
7167
# test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)

0 commit comments

Comments
 (0)