22 CorrBijector <: Bijector
33
44A bijector implementation of Stan's parametrization method for Correlation matrix:
5- https://mc-stan.org/docs/2_23/ reference-manual/correlation-matrix-transform-section.html
5+ https://mc-stan.org/docs/reference-manual/transforms.html# correlation-matrix-transform.section
66
77Basically, a unconstrained strictly upper triangular matrix `y` is transformed to
88a correlation matrix by following readable but not that efficient form:
@@ -348,13 +348,12 @@ function _inv_link_chol_lkj(Y::AbstractMatrix)
348348 T = float (eltype (W))
349349 logJ = zero (T)
350350
351- idx = 1
352351 @inbounds for j in 1 : K
353352 log_remainder = zero (T) # log of proportion of unit vector remaining
354353 for i in 1 : (j - 1 )
355354 z = tanh (Y[i, j])
356355 W[i, j] = z * exp (log_remainder)
357- log_remainder += log1p ( - z ^ 2 ) / 2
356+ log_remainder -= LogExpFunctions . logcosh (Y[i, j])
358357 logJ += log_remainder
359358 end
360359 logJ += log_remainder
@@ -375,15 +374,18 @@ function _inv_link_chol_lkj(y::AbstractVector)
375374 T = float (eltype (W))
376375 logJ = zero (T)
377376
377+ z_vec = map (tanh, y)
378+ lc_vec = map (LogExpFunctions. logcosh, y)
379+
378380 idx = 1
379381 @inbounds for j in 1 : K
380382 log_remainder = zero (T) # log of proportion of unit vector remaining
381383 for i in 1 : (j - 1 )
382- z = tanh (y[idx])
383- idx += 1
384+ z = z_vec[idx]
384385 W[i, j] = z * exp (log_remainder)
385- log_remainder += log1p ( - z ^ 2 ) / 2
386+ log_remainder -= lc_vec[idx]
386387 logJ += log_remainder
388+ idx += 1
387389 end
388390 logJ += log_remainder
389391 W[j, j] = exp (log_remainder)
@@ -404,18 +406,19 @@ function _inv_link_chol_lkj_rrule(y::AbstractVector)
404406 T = typeof (log (one (eltype (W))))
405407 logJ = zero (T)
406408
407- z_vec = tanh .(y)
409+ z_vec = map (tanh, y)
410+ lc_vec = map (LogExpFunctions. logcosh, y)
408411
409412 idx = 1
410413 W[1 , 1 ] = 1
411414 @inbounds for j in 2 : K
412415 log_remainder = zero (T) # log of proportion of unit vector remaining
413416 for i in 1 : (j - 1 )
414417 z = z_vec[idx]
415- idx += 1
416418 W[i, j] = z * exp (log_remainder)
417- log_remainder += log1p ( - z ^ 2 ) / 2
419+ log_remainder -= lc_vec[idx]
418420 logJ += log_remainder
421+ idx += 1
419422 end
420423 logJ += log_remainder
421424 W[j, j] = exp (log_remainder)
@@ -461,13 +464,8 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix)
461464 K = LinearAlgebra. checksquare (Y)
462465
463466 result = float (zero (eltype (Y)))
464- for j in 2 : K, i in 1 : (j - 1 )
465- @inbounds abs_y_i_j = abs (Y[i, j])
466- result +=
467- (K - i + 1 ) * (
468- IrrationalConstants. logtwo -
469- (abs_y_i_j + LogExpFunctions. log1pexp (- 2 * abs_y_i_j))
470- )
467+ @inbounds for j in 2 : K, i in 1 : (j - 1 )
468+ result -= (K - i + 1 ) * LogExpFunctions. logcosh (Y[i, j])
471469 end
472470 return result
473471end
@@ -477,13 +475,8 @@ function _logabsdetjac_inv_corr(y::AbstractVector)
477475
478476 result = float (zero (eltype (y)))
479477 for (i, y_i) in enumerate (y)
480- abs_y_i = abs (y_i)
481478 row_idx = vec_to_triu1_row_index (i)
482- result +=
483- (K - row_idx + 1 ) * (
484- IrrationalConstants. logtwo -
485- (abs_y_i + LogExpFunctions. log1pexp (- 2 * abs_y_i))
486- )
479+ result -= (K - row_idx + 1 ) * LogExpFunctions. logcosh (y_i)
487480 end
488481 return result
489482end
@@ -496,10 +489,9 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
496489 @inbounds for j in 2 : K
497490 tmp = zero (result)
498491 for _ in 1 : (j - 1 )
499- z = tanh (y[idx])
500- logz = log (1 - z^ 2 )
501- result += logz + (tmp / 2 )
502- tmp += logz
492+ logcoshy = LogExpFunctions. logcosh (y[idx])
493+ tmp -= logcoshy
494+ result += tmp - logcoshy
503495 idx += 1
504496 end
505497 end
0 commit comments