Skip to content

Commit b35ae0a

Browse files
Symbolic normalizing constant
1 parent a6ab223 commit b35ae0a

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

pymc/distributions/multivariate.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,31 +1150,30 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
11501150

11511151

11521152
def _lkj_normalizing_constant(eta, n):
1153-
# TODO: This is mixing python branching with the potentially symbolic n and eta variables
1154-
if not isinstance(eta, int | float):
1155-
raise NotImplementedError("eta must be an int or float")
1156-
if not isinstance(n, int):
1157-
raise NotImplementedError("n must be an integer")
1158-
if eta == 1:
1159-
result = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum()
1160-
if n % 2 == 1:
1161-
result += (
1153+
result_1 = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum()
1154+
result_2 = -(n - 1) * gammaln(eta + 0.5 * (n - 1))
1155+
k = pt.arange(1, n)
1156+
1157+
return pt.switch(
1158+
pt.eq(eta, 1.0),
1159+
pt.switch(
1160+
pt.eq(n % 2, 1.0),
1161+
result_1
1162+
+ (
11621163
0.25 * (n**2 - 1) * pt.log(np.pi)
11631164
- 0.25 * (n - 1) ** 2 * pt.log(2.0)
11641165
- (n - 1) * gammaln((n + 1) / 2)
1165-
)
1166-
else:
1167-
result += (
1166+
),
1167+
result_1
1168+
+ (
11681169
0.25 * n * (n - 2) * pt.log(np.pi)
11691170
+ 0.25 * (3 * n**2 - 4 * n) * pt.log(2.0)
11701171
+ n * gammaln(n / 2)
11711172
- (n - 1) * gammaln(n)
1172-
)
1173-
else:
1174-
result = -(n - 1) * gammaln(eta + 0.5 * (n - 1))
1175-
k = pt.arange(1, n)
1176-
result += (0.5 * k * pt.log(np.pi) + gammaln(eta + 0.5 * (n - 1 - k))).sum()
1177-
return result
1173+
),
1174+
),
1175+
result_2 + (0.5 * k * pt.log(np.pi) + gammaln(eta + 0.5 * (n - 1 - k))).sum(),
1176+
)
11781177

11791178

11801179
# _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
@@ -1603,13 +1602,17 @@ class _LKJCorr(BoundedContinuous):
16031602
def dist(cls, n, eta, **kwargs):
16041603
n = pt.as_tensor_variable(n).astype(int)
16051604
eta = pt.as_tensor_variable(eta)
1606-
rng = kwargs.pop("rng", None)
16071605

1608-
if isinstance(rng, Variable):
1609-
rng = rng.get_value()
1610-
1611-
kwargs["scan_rng"] = pytensor.shared(np.random.default_rng(rng))
1612-
kwargs["outer_rng"] = pytensor.shared(np.random.default_rng(rng))
1606+
# In general, RVs are expected to take an "rng" argument. We allow it here to prevent API break, but
1607+
# ignore it. This can be changed in the future if we relax the requirement that rng be a shared variable
1608+
# in a scan.
1609+
rng = kwargs.pop("rng", None)
1610+
if rng is not None:
1611+
warnings.warn(
1612+
"You passed a random generator to LKJCorr via the `rng` keyword argument, but it is not "
1613+
"used. To seed LKJCorr, pass two random generators via the `outer_rng` and `scan_rng` "
1614+
"keyword arguments.",
1615+
)
16131616

16141617
return super().dist([n, eta], **kwargs)
16151618

@@ -1631,17 +1634,12 @@ def logp(value: TensorVariable, n, eta):
16311634
-------
16321635
TensorVariable
16331636
"""
1634-
# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
1637+
# n has to be a constant, otherwise the shape of the RV would not be fixed between draws.
16351638
try:
16361639
n = int(get_underlying_scalar_constant_value(n))
16371640
except NotScalarConstantError:
16381641
raise NotImplementedError("logp only implemented for constant `n`")
16391642

1640-
try:
1641-
eta = float(get_underlying_scalar_constant_value(eta))
1642-
except NotScalarConstantError:
1643-
raise NotImplementedError("logp only implemented for constant `eta`")
1644-
16451643
result = _lkj_normalizing_constant(eta, n)
16461644
result += (eta - 1.0) * pt.log(det(value))
16471645

0 commit comments

Comments
 (0)