Skip to content

Commit 82d4cdf

Browse files
authored
Cholesky numerical stability: Forward transform (#357)
* Numerical improvements to LKJCholesky forward transform
* Update implementation of Cholesky forward transform in rrule
* Fix FD Jacobian test
* Fix `_link_chol_lkj...` rrule test
* Fix ChainRules test (properly this time)
* Bump patch
1 parent fa5e9e4 commit 82d4cdf

File tree

5 files changed

+172
-68
lines changed

5 files changed

+172
-68
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.7"
3+
version = "0.15.8"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/bijectors/corr.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,15 @@ which is the above implementation.
293293
function _link_chol_lkj(W::AbstractMatrix)
294294
K = LinearAlgebra.checksquare(W)
295295

296-
y = similar(W) # z is also UpperTriangular.
296+
y = similar(W) # W is upper triangular.
297297
# Some zero filling can be avoided, though the diagonal still needs to be filled with zeros.
298298

299299
@inbounds for j in 1:K
300-
remainder_sq = one(eltype(W))
301-
for i in 1:(j - 1)
300+
remainder_sq = W[j, j]^2
301+
for i in (j - 1):-1:1
302302
z = W[i, j] / sqrt(remainder_sq)
303-
y[i, j] = atanh(z)
304-
remainder_sq -= W[i, j]^2
303+
y[i, j] = asinh(z)
304+
remainder_sq += W[i, j]^2
305305
end
306306
for i in j:K
307307
y[i, j] = 0
@@ -317,17 +317,18 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix)
317317

318318
y = similar(W, N)
319319

320-
idx = 1
320+
starting_idx = 1
321321
@inbounds for j in 2:K
322-
y[idx] = atanh(W[1, j])
323-
idx += 1
324-
remainder_sq = 1 - W[1, j]^2
325-
for i in 2:(j - 1)
322+
y[starting_idx] = atanh(W[1, j])
323+
starting_idx += 1
324+
remainder_sq = W[j, j]^2
325+
for i in (j - 1):-1:2
326+
idx = starting_idx + i - 2
326327
z = W[i, j] / sqrt(remainder_sq)
327-
y[idx] = atanh(z)
328-
remainder_sq -= W[i, j]^2
329-
idx += 1
328+
y[idx] = asinh(z)
329+
remainder_sq += W[i, j]^2
330330
end
331+
starting_idx += length((j - 1):-1:2)
331332
end
332333

333334
return y

src/chainrules.jl

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -161,21 +161,23 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_upper), W::AbstractMa
161161
N = ((K - 1) * K) ÷ 2
162162

163163
z = zeros(eltype(W), N)
164-
tmp_vec = similar(z)
164+
remainders = similar(z)
165165

166-
idx = 1
166+
starting_idx = 1
167167
@inbounds for j in 2:K
168-
z[idx] = atanh(W[1, j])
169-
tmp = sqrt(1 - W[1, j]^2)
170-
tmp_vec[idx] = tmp
171-
idx += 1
172-
for i in 2:(j - 1)
173-
p = W[i, j] / tmp
174-
tmp *= sqrt(1 - p^2)
175-
tmp_vec[idx] = tmp
176-
z[idx] = atanh(p)
177-
idx += 1
168+
z[starting_idx] = atanh(W[1, j])
169+
remainder_sq = W[j, j]^2
170+
starting_idx += 1
171+
for i in (j - 1):-1:2
172+
idx = starting_idx + i - 2
173+
remainder = sqrt(remainder_sq)
174+
remainders[idx] = remainder
175+
zt = W[i, j] / remainder
176+
z[idx] = asinh(zt)
177+
remainder_sq += W[i, j]^2
178178
end
179+
remainders[starting_idx - 1] = sqrt(remainder_sq)
180+
starting_idx += length((j - 1):-1:2)
179181
end
180182

181183
function pullback_link_chol_lkj_from_upper(Δz_thunked)
@@ -190,7 +192,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_upper), W::AbstractMa
190192
ΔW[j, j] = 0
191193
Δtmp = zero(eltype(Δz))
192194
for i in (j - 1):-1:2
193-
tmp = tmp_vec[idx_up_to_prev_column + i - 1]
195+
tmp = remainders[idx_up_to_prev_column + i - 1]
194196
p = W[i, j] / tmp
195197
ftmp = sqrt(1 - p^2)
196198
d_ftmp_p = -p / ftmp
@@ -216,21 +218,23 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMa
216218
N = ((K - 1) * K) ÷ 2
217219

218220
z = zeros(eltype(W), N)
219-
tmp_vec = similar(z)
221+
remainders = similar(z)
220222

221-
idx = 1
223+
starting_idx = 1
222224
@inbounds for i in 2:K
223-
z[idx] = atanh(W[i, 1])
224-
tmp = sqrt(1 - W[i, 1]^2)
225-
tmp_vec[idx] = tmp
226-
idx += 1
227-
for j in 2:(i - 1)
228-
p = W[i, j] / tmp
229-
tmp *= sqrt(1 - p^2)
230-
tmp_vec[idx] = tmp
231-
z[idx] = atanh(p)
232-
idx += 1
225+
z[starting_idx] = atanh(W[i, 1])
226+
remainder_sq = W[i, i]^2
227+
starting_idx += 1
228+
for j in (i - 1):-1:2
229+
idx = starting_idx + j - 2
230+
remainder = sqrt(remainder_sq)
231+
remainders[idx] = remainder
232+
zt = W[i, j] / remainder
233+
z[idx] = asinh(zt)
234+
remainder_sq += W[i, j]^2
233235
end
236+
remainders[starting_idx - 1] = sqrt(remainder_sq)
237+
starting_idx += length((i - 1):-1:2)
234238
end
235239

236240
function pullback_link_chol_lkj_from_lower(Δz_thunked)
@@ -245,7 +249,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMa
245249
ΔW[i, i] = 0
246250
Δtmp = zero(eltype(Δz))
247251
for j in (i - 1):-1:2
248-
tmp = tmp_vec[idx_up_to_prev_row + j - 1]
252+
tmp = remainders[idx_up_to_prev_row + j - 1]
249253
p = W[i, j] / tmp
250254
ftmp = sqrt(1 - p^2)
251255
d_ftmp_p = -p / ftmp

test/ad/chainrules.jl

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Random: Xoshiro
2+
using LinearAlgebra
13
using ChainRulesTestUtils: ChainRulesCore
24

35
# HACK: This is a workaround to test `Bijectors._inv_link_chol_lkj` which produces an
@@ -77,30 +79,104 @@ end
7779
test_rrule(Bijectors._transform_inverse_ordered, b(rand(5, 2)))
7880

7981
# LKJ and LKJCholesky bijector
80-
dist = LKJCholesky(3, 4)
8182
# Run multiple tests because we're working with `undef` entries, and so we
8283
# want to make sure that we hit cases where the `undef` entries have different values.
8384
# It's also just useful to test numerical stability for different realizations of `dist`.
85+
86+
# NOTE(penelopeysm): https://github.com/TuringLang/Bijectors.jl/pull/357
87+
# changed the implementation of _link_chol_lkj... to improve its numerical stability.
88+
# The new implementation relies on the fact that the `LKJCholesky` distribution
89+
# yields samples for which each column is a unit vector. Naively using FiniteDifferences
90+
# to calculate a JVP (as ChainRulesTestUtils.test_rrule does) does not work, because
91+
# FD does not know about this constraint.
92+
# To solve this, we run the FD part of the test with the inputs projected onto a
93+
# subspace that has that constraint encoded. We have to then recover the original
94+
# output by un-projecting.
95+
# The PR linked above has a more detailed explanation.
8496
for i in 1:30
85-
x = rand(dist)
86-
test_rrule(
87-
Bijectors._link_chol_lkj_from_upper,
88-
x.U;
89-
testset_name="_link_chol_lkj_from_upper on $(typeof(x)) [$i]",
90-
)
91-
test_rrule(
92-
Bijectors._link_chol_lkj_from_lower,
93-
x.L;
94-
testset_name="_link_chol_lkj_from_lower on $(typeof(x)) [$i]",
95-
)
97+
dist = LKJCholesky(3, 4)
98+
rng = Xoshiro(i)
99+
spl = rand(rng, dist)
96100

97-
b = bijector(dist)
98-
y = b(x)
101+
@testset "_inv_link_chol_lkj" begin
102+
# This one doesn't need the fancy projection bits, so we can just
103+
# use test_rrule as usual.
104+
x = spl
105+
b = bijector(dist)
106+
y = b(x)
107+
test_rrule(
108+
_inv_link_chol_lkj_wrapper,
109+
y;
110+
testset_name="_inv_link_chol_lkj on $(typeof(x)) [$i]",
111+
)
112+
end
99113

100-
test_rrule(
101-
_inv_link_chol_lkj_wrapper,
102-
y;
103-
testset_name="_inv_link_chol_lkj on $(typeof(x)) [$i]",
104-
)
114+
# Set up a random tangent.
115+
ybar = rand(rng, 3) * 10
116+
fdm = FiniteDifferences.central_fdm(5, 1)
117+
118+
# Functions to convert input to/from free parameters
119+
to_free_params(x::UpperTriangular) = [x[1, 2], x[1, 3], x[2, 3]]
120+
to_free_params(x::LowerTriangular) = [x[2, 1], x[3, 1], x[3, 2]]
121+
function from_x_free(x_free::AbstractVector, uplo::Symbol)
122+
x = UpperTriangular(zeros(eltype(x_free), 3, 3))
123+
x[1, 1] = 1
124+
x[1, 2] = x_free[1]
125+
x[1, 3] = x_free[2]
126+
x[2, 2] = sqrt(1 - x_free[1]^2)
127+
x[2, 3] = x_free[3]
128+
x[3, 3] = sqrt(1 - x_free[2]^2 - x_free[3]^2)
129+
return uplo == :U ? x : transpose(x)
130+
end
131+
# Function to reconvert the adjoint back into a triangular matrix
132+
function fd_xbar_to_cr_xbar(fd_xbar::AbstractVector, uplo::Symbol)
133+
x = UpperTriangular(zeros(eltype(fd_xbar), 3, 3))
134+
x[1, 2] = fd_xbar[1]
135+
x[1, 3] = fd_xbar[2]
136+
x[2, 3] = fd_xbar[3]
137+
return uplo == :U ? x : transpose(x)
138+
end
139+
140+
@testset "_link_chol_lkj_from_upper" begin
141+
f = Bijectors._link_chol_lkj_from_upper
142+
x = spl.U
143+
144+
# test primal is accurate
145+
y = f(x)
146+
cr_y, cr_pullback = ChainRulesCore.rrule(f, x)
147+
@test isapprox(y, cr_y)
148+
149+
# test that the primal still works when going via free parameters
150+
f_via_free(x_free::AbstractVector) = f(from_x_free(x_free, :U))
151+
x_free = to_free_params(x)
152+
y_via_free = f_via_free(x_free)
153+
@test isapprox(y, y_via_free)
154+
155+
# test pullback
156+
cr_xbar = cr_pullback(ybar)[2]
157+
fd_xbar = FiniteDifferences.j′vp(fdm, f_via_free, ybar, x_free)[1]
158+
@test isapprox(cr_xbar, fd_xbar_to_cr_xbar(fd_xbar, :U))
159+
end
160+
161+
@testset "_link_chol_lkj_from_lower" begin
162+
f = Bijectors._link_chol_lkj_from_lower
163+
x = spl.L
164+
165+
# test primal is accurate
166+
y = f(x)
167+
cr_y, cr_pullback = ChainRulesCore.rrule(f, x)
168+
@test isapprox(y, cr_y)
169+
170+
# test that the primal still works when going via free parameters
171+
f_via_free(x_free::AbstractVector) = f(from_x_free(x_free, :L))
172+
x_free = to_free_params(x)
173+
y_via_free = f_via_free(x_free)
174+
@test isapprox(y, y_via_free)
175+
176+
# test pullback
177+
cr_xbar = cr_pullback(ybar)[2]
178+
fd_xbar = FiniteDifferences.j′vp(fdm, f_via_free, ybar, x_free)[1]
179+
@test isapprox(cr_xbar, fd_xbar_to_cr_xbar(fd_xbar, :L))
180+
end
105181
end
106182
end

test/transform.jl

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,19 +215,42 @@ end
215215
end
216216

217217
@testset "LKJCholesky" begin
218+
# Convert Cholesky factor to its free parameters, i.e. its off-diagonal elements
219+
function chol_3by3_to_free_params(x::Cholesky)
220+
if x.uplo == :U
221+
return [x.U[1, 2], x.U[1, 3], x.U[2, 3]]
222+
else
223+
return [x.L[2, 1], x.L[3, 1], x.L[3, 2]]
224+
end
225+
end
226+
227+
# Reconstruct Cholesky factor from its free parameters
228+
# Note that x[i, i] is always positive so we don't need to worry about the sign
229+
function free_params_to_chol_3by3(free_params::AbstractVector, uplo::Symbol)
230+
x = UpperTriangular(zeros(eltype(free_params), 3, 3))
231+
x[1, 1] = 1
232+
x[1, 2] = free_params[1]
233+
x[1, 3] = free_params[2]
234+
x[2, 2] = sqrt(1 - free_params[1]^2)
235+
x[2, 3] = free_params[3]
236+
x[3, 3] = sqrt(1 - free_params[2]^2 - free_params[3]^2)
237+
if uplo == :U
238+
return Cholesky(x)
239+
else
240+
return Cholesky(transpose(x))
241+
end
242+
end
243+
218244
@testset "uplo: $uplo" for uplo in [:L, :U]
219245
dist = LKJCholesky(3, 1, uplo)
220246
single_sample_tests(dist)
221247
x = rand(dist)
222-
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
223-
# Remove columns of Jacobian that are all zero (i.e. those
224-
# corresponding to entries above the diagonal for uplo = :U, or below
225-
# the diagonal for uplo = :L). This slightly unscientific approach
226-
# based on filter() is needed to handle both ForwardDiff 0.10 and 1 as
227-
# the exact indices will differ for the two versions; see
228-
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/738.
229-
inds = filter(i -> !all(iszero, J[:, i]), 1:size(J, 2))
230-
J = J[:, inds]
248+
# Here, we need to pass ForwardDiff only the free parameters of the
249+
# Cholesky factor so that we get a square Jacobian matrix
250+
free_params = chol_3by3_to_free_params(x)
251+
J = ForwardDiff.jacobian(
252+
z -> link(dist, free_params_to_chol_3by3(z, uplo)), free_params
253+
)
231254
logpdf_turing = logpdf_with_trans(dist, x, true)
232255
@test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
233256
end

0 commit comments

Comments
 (0)