@@ -1150,31 +1150,30 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
 
 
 def _lkj_normalizing_constant(eta, n):
-    # TODO: This is mixing python branching with the potentially symbolic n and eta variables
-    if not isinstance(eta, int | float):
-        raise NotImplementedError("eta must be an int or float")
-    if not isinstance(n, int):
-        raise NotImplementedError("n must be an integer")
-    if eta == 1:
-        result = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum()
-        if n % 2 == 1:
-            result += (
+    result_1 = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum()
+    result_2 = -(n - 1) * gammaln(eta + 0.5 * (n - 1))
+    k = pt.arange(1, n)
+
+    return pt.switch(
+        pt.eq(eta, 1.0),
+        pt.switch(
+            pt.eq(n % 2, 1.0),
+            result_1
+            + (
                 0.25 * (n**2 - 1) * pt.log(np.pi)
                 - 0.25 * (n - 1) ** 2 * pt.log(2.0)
                 - (n - 1) * gammaln((n + 1) / 2)
-            )
-        else:
-            result += (
+            ),
+            result_1
+            + (
                 0.25 * n * (n - 2) * pt.log(np.pi)
                 + 0.25 * (3 * n**2 - 4 * n) * pt.log(2.0)
                 + n * gammaln(n / 2)
                 - (n - 1) * gammaln(n)
-            )
-    else:
-        result = -(n - 1) * gammaln(eta + 0.5 * (n - 1))
-        k = pt.arange(1, n)
-        result += (0.5 * k * pt.log(np.pi) + gammaln(eta + 0.5 * (n - 1 - k))).sum()
-    return result
+            ),
+        ),
+        result_2 + (0.5 * k * pt.log(np.pi) + gammaln(eta + 0.5 * (n - 1 - k))).sum(),
+    )
 
 
 # _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
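
Note on the rewrite above: a Python `if` cannot branch on a symbolic `eta`, which is why the old code raised `NotImplementedError` for non-constant inputs. `pt.switch` instead builds every branch into a single graph and selects among them at run time. A minimal standalone sketch of the mechanism (not part of the diff; it only assumes the public `pytensor` API):

    import pytensor
    import pytensor.tensor as pt

    eta = pt.dscalar("eta")  # symbolic: a plain `if eta == 1:` would raise here
    # both branches live in one graph; the comparison picks one at run time
    expr = pt.switch(pt.eq(eta, 1.0), pt.gammaln(3.0), -pt.gammaln(eta + 1.0))
    f = pytensor.function([eta], expr)
    print(f(1.0), f(2.5))  # one compiled function serves every value of eta

Unlike the lazy `pytensor.ifelse`, `pt.switch` computes both branches before selecting, which is acceptable in `_lkj_normalizing_constant` because each branch is a cheap scalar reduction.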
@@ -1603,13 +1602,17 @@ class _LKJCorr(BoundedContinuous):
     def dist(cls, n, eta, **kwargs):
         n = pt.as_tensor_variable(n).astype(int)
         eta = pt.as_tensor_variable(eta)
-        rng = kwargs.pop("rng", None)
 
-        if isinstance(rng, Variable):
-            rng = rng.get_value()
-
-        kwargs["scan_rng"] = pytensor.shared(np.random.default_rng(rng))
-        kwargs["outer_rng"] = pytensor.shared(np.random.default_rng(rng))
+        # In general, RVs are expected to take an "rng" argument. We accept it here to avoid
+        # breaking the API, but ignore it. This can be changed in the future if we relax the
+        # requirement that the rng be a shared variable in a scan.
+        rng = kwargs.pop("rng", None)
+        if rng is not None:
+            warnings.warn(
+                "You passed a random generator to LKJCorr via the `rng` keyword argument, but it is not "
+                "used. To seed LKJCorr, pass two random generators via the `outer_rng` and `scan_rng` "
+                "keyword arguments.",
+            )
 
         return super().dist([n, eta], **kwargs)
 
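
For reference, seeding under the new scheme would look something like the sketch below. This is hypothetical usage, not code from the PR: it assumes `outer_rng` and `scan_rng` are forwarded to the RV as shared variables holding NumPy generators, exactly as the removed lines and the warning text suggest.

    import numpy as np
    import pytensor
    import pymc as pm

    # scan requires its rngs to be shared variables, hence pytensor.shared(...)
    corr = pm.LKJCorr.dist(
        n=3,
        eta=2.0,
        outer_rng=pytensor.shared(np.random.default_rng(42)),
        scan_rng=pytensor.shared(np.random.default_rng(43)),
    )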
@@ -1631,17 +1634,12 @@ def logp(value: TensorVariable, n, eta):
         -------
         TensorVariable
         """
-        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
+        # n has to be a constant, otherwise the shape of the RV would not be fixed between draws.
         try:
             n = int(get_underlying_scalar_constant_value(n))
         except NotScalarConstantError:
             raise NotImplementedError("logp only implemented for constant `n`")
 
-        try:
-            eta = float(get_underlying_scalar_constant_value(eta))
-        except NotScalarConstantError:
-            raise NotImplementedError("logp only implemented for constant `eta`")
-
         result = _lkj_normalizing_constant(eta, n)
         result += (eta - 1.0) * pt.log(det(value))
 
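The upshot of this last hunk: only `n` must be a compile-time constant, since it fixes the shape of the RV between draws, while `eta` may now stay symbolic, e.g. a free variable being sampled. A hypothetical sketch of what this permits (it assumes the rest of the PR wires a symbolic `eta` through to the logp):

    import pymc as pm

    with pm.Model():
        eta = pm.Exponential("eta", 1.0)         # symbolic eta: allowed after this change
        corr = pm.LKJCorr("corr", n=3, eta=eta)  # n must remain a constant integer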