Skip to content

Commit da9650c

Browse files
authored
Fix Dirichlet logpdf_with_trans to work with a Vector{Real} (#326)
1 parent dc6b21f commit da9650c

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

src/Bijectors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ function logpdf_with_trans(d::Distribution, x, transform::Bool)
164164
if ispd(d)
165165
return pd_logpdf_with_trans(d, x, transform)
166166
elseif isdirichlet(d)
167-
l = logpdf(d, x .+ eps(eltype(x)))
167+
l = logpdf(d, x .+ _eps(eltype(x)))
168168
else
169169
l = logpdf(d, x)
170170
end

test/transform.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,28 +42,51 @@ function single_sample_tests(dist)
4242

4343
# Check that invlink is inverse of link.
4444
x = rand(dist)
45+
_single_sample_tests_inner(dist, x, ϵ)
4546

47+
# If the sample is a vector of scalars, check that we can run the tests even if the
48+
# vector has the abstract element type Real. Skip type stability tests though.
49+
if x isa Vector{<:Real}
50+
_single_sample_tests_inner(dist, Vector{Real}(x), ϵ, false)
51+
end
52+
end
53+
54+
function _single_sample_tests_inner(dist, x, ϵ, test_type_stability=true)
4655
if dist isa LKJCholesky
4756
x_inv = @inferred Cholesky{Float64,Matrix{Float64}} invlink(
4857
dist, link(dist, copy(x))
4958
)
5059
@test x_inv.UL x.UL atol = 1e-9
5160
else
52-
@test @inferred(invlink(dist, link(dist, copy(x)))) x atol = 1e-9
61+
x_reconstructed = if test_type_stability
62+
@inferred invlink(dist, link(dist, copy(x)))
63+
else
64+
invlink(dist, link(dist, copy(x)))
65+
end
66+
@test x_reconstructed x atol = 1e-9
5367
end
5468

5569
# Check that link is inverse of invlink. Hopefully this just holds given the above...
56-
y = @inferred(link(dist, x))
70+
y = if test_type_stability
71+
@inferred(link(dist, x))
72+
else
73+
link(dist, x)
74+
end
75+
y_reconstructed = if test_type_stability
76+
@inferred(link(dist, invlink(dist, copy(y))))
77+
else
78+
link(dist, invlink(dist, copy(y)))
79+
end
5780
if dist isa Dirichlet
5881
# `logit` and `logistic` are not perfect inverses. This leads to a diversion.
5982
# Example:
6083
# julia> logit(logistic(0.9999999999999998))
6184
# 1.0
6285
# julia> logistic(logit(0.9999999999999998))
6386
# 0.9999999999999998
64-
@test @inferred(link(dist, invlink(dist, copy(y)))) y atol = 0.5
87+
@test y_reconstructed y atol = 0.5
6588
else
66-
@test @inferred(link(dist, invlink(dist, copy(y)))) y atol = 1e-9
89+
@test y_reconstructed y atol = 1e-9
6790
end
6891
if dist isa SimplexDistribution
6992
# This should probably be exact.

0 commit comments

Comments
 (0)