Skip to content

Commit 6d09505

Browse files
authored
add ForwardDiff@1 (#378)
* add ForwardDiff@1 * Force ForwardDiff=1 in tests * Don't test Dirichlet AD at non-differentiable points * Don't run duplicate tests * Use ForwardDiff 1.0.1, fix LKJCholesky Jacobian test * Remove dead code * Don't force FD=1 in test suite * Add reference to ForwardDiff issue
1 parent 9dbf889 commit 6d09505

File tree

5 files changed

+13
-45
lines changed

5 files changed

+13
-45
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.6"
3+
version = "0.15.7"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -51,7 +51,7 @@ Distributions = "0.25.33"
5151
DistributionsAD = "0.6"
5252
DocStringExtensions = "0.9"
5353
EnzymeCore = "0.8.4"
54-
ForwardDiff = "0.10"
54+
ForwardDiff = "0.10, 1.0.1"
5555
Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
5656
InverseFunctions = "0.1"
5757
IrrationalConstants = "0.1, 0.2"

src/Bijectors.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ _eps(::Type{<:Integer}) = eps(Float64)
8686

8787
function _clamp(x, a, b)
8888
T = promote_type(typeof(x), typeof(a), typeof(b))
89-
ϵ = _eps(T)
9089
clamped_x = ifelse(x < a, convert(T, a), ifelse(x > b, convert(T, b), x))
9190
DEBUG && _debug("x = $x, bounds = $((a, b)), clamped_x = $clamped_x")
9291
return clamped_x

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Enzyme = "0.13.12"
3939
EnzymeTestUtils = "0.2.1"
4040
FillArrays = "1"
4141
FiniteDifferences = "0.11, 0.12"
42-
ForwardDiff = "0.10.12"
42+
ForwardDiff = "0.10, 1.0.1"
4343
Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
4444
InverseFunctions = "0.1"
4545
LazyArrays = "1, 2"

test/interface.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,10 @@ end
145145
@testset "Multivariate" begin
146146
vector_dists = [
147147
Dirichlet(2, 3),
148-
Dirichlet([1000 * one(Float64), eps(Float64)]),
149-
Dirichlet([eps(Float64), 1000 * one(Float64)]),
148+
Dirichlet([10.0, 0.1]),
149+
Dirichlet([0.1, 10.0]),
150150
MvNormal(randn(10), Diagonal(exp.(randn(10)))),
151151
MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
152-
Dirichlet([1000 * one(Float64), eps(Float64)]),
153-
Dirichlet([eps(Float64), 1000 * one(Float64)]),
154152
MvTDist(1, randn(10), Matrix(Diagonal(exp.(randn(10))))),
155153
transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
156154
transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))),
@@ -173,15 +171,7 @@ end
173171
# similar to what we do in test/transform.jl for Dirichlet
174172
if dist isa Dirichlet
175173
b = Bijectors.SimplexBijector()
176-
# HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]`
177-
# which in turn will lead to differences between `ForwardDiff.jacobian`
178-
# and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`.
179-
# We therefore test the realizations _on_ the boundary rather if we're near the boundary.
180-
x = if any(rand(dist) .> 0.9999)
181-
[0.0, 1.0][sortperm(rand(dist))]
182-
else
183-
rand(dist)
184-
end
174+
x = rand(dist)
185175
y = b(x)
186176
@test b(param(x)) isa TrackedArray
187177
@test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1]

test/transform.jl

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -153,32 +153,10 @@ end
153153
Dirichlet([eps(Float64), 1000 * one(Float64)]),
154154
MvNormal(randn(10), Diagonal(exp.(randn(10)))),
155155
MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
156-
Dirichlet([1000 * one(Float64), eps(Float64)]),
157-
Dirichlet([eps(Float64), 1000 * one(Float64)]),
158156
]
159157
for dist in vector_dists
160158
if dist isa Dirichlet
161159
single_sample_tests(dist)
162-
163-
# This should fail at the minute. Not sure what the correct way to test this is.
164-
165-
# Workaround for intermittent test failures, result of `logpdf_with_trans(dist, x, true)`
166-
# is incorrect for `x == [0.9999999999999998, 0.0]`:
167-
x =
168-
if params(dist) ==
169-
params(Dirichlet([1000 * one(Float64), eps(Float64)]))
170-
[1.0, 0.0]
171-
else
172-
rand(dist)
173-
end
174-
# `Dirichlet` is no longer mapping between spaces of the same dimensionality,
175-
# so the block below no longer works.
176-
if !(dist isa Dirichlet)
177-
logpdf_turing = logpdf_with_trans(dist, x, true)
178-
J = ForwardDiff.jacobian(x -> link(dist, x), x)
179-
@test logpdf(dist, x .+ ϵ) - _logabsdet(J) logpdf_turing
180-
end
181-
182160
# Issue #12
183161
stepsize = 1e10
184162
dim = Bijectors.output_length(bijector(dist), length(dist))
@@ -240,14 +218,15 @@ end
240218
@testset "uplo: $uplo" for uplo in [:L, :U]
241219
dist = LKJCholesky(3, 1, uplo)
242220
single_sample_tests(dist)
243-
244221
x = rand(dist)
245-
246-
inds = [
247-
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
248-
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
249-
]
250222
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
223+
# Remove columns of Jacobian that are all zero (i.e. those
224+
# corresponding to entries above the diagonal for uplo = :U, or below
225+
# the diagonal for uplo = :L). This slightly unscientific approach
226+
# based on filter() is needed to handle both ForwardDiff 0.10 and 1 as
227+
# the exact indices will differ for the two versions; see
228+
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/738.
229+
inds = filter(i -> !all(iszero, J[:, i]), 1:size(J, 2))
251230
J = J[:, inds]
252231
logpdf_turing = logpdf_with_trans(dist, x, true)
253232
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing

0 commit comments

Comments
 (0)