Skip to content

Commit 813aece

Browse files
committed
Remove discrete selfconsistency icdf
1 parent 6a7111a commit 813aece

File tree

3 files changed

+29
-78
lines changed

3 files changed

+29
-78
lines changed

pymc/testing.py

Lines changed: 27 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -668,88 +668,50 @@ def check_selfconsistency_discrete_logcdf(
668668
)
669669

670670

671-
def check_selfconsistency_continuous_icdf(
671+
def check_selfconsistency_icdf(
672672
distribution: Distribution,
673-
paramdomains: Dict[str, Domain],
674-
decimal: Optional[int] = None,
673+
paramdomains: dict[str, Domain],
674+
*,
675+
decimal: int | None = None,
675676
n_samples: int = 100,
676677
) -> None:
677-
"""
678-
Check that the icdf and logcdf functions of the distribution are consistent for a sample of probability values.
679-
"""
680-
if decimal is None:
681-
decimal = select_by_precision(float64=6, float32=3)
682-
683-
dist = create_dist_from_paramdomains(distribution, paramdomains)
684-
value = dist.type()
685-
value.name = "value"
686-
687-
dist_icdf = icdf(dist, value)
688-
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)
689-
690-
dist_logcdf = logcdf(dist, value)
691-
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
692-
693-
domains = paramdomains.copy()
694-
domains["value"] = Domain(np.linspace(0, 1, 10))
695-
696-
for point in product(domains, n_samples=n_samples):
697-
point = dict(point)
698-
value = point.pop("value")
699-
700-
with pytensor.config.change_flags(mode=Mode("py")):
701-
npt.assert_almost_equal(
702-
value,
703-
np.exp(dist_logcdf_fn(**point, value=dist_icdf_fn(**point, value=value))),
704-
decimal=decimal,
705-
err_msg=f"point: {point}, value: {value}",
706-
)
678+
"""Check that the icdf and logcdf functions of the distribution are consistent.
707679
708-
709-
def check_selfconsistency_discrete_icdf(
710-
distribution: Distribution,
711-
domain: Domain,
712-
paramdomains: Dict[str, Domain],
713-
decimal: Optional[int] = None,
714-
n_samples: int = 100,
715-
) -> None:
716-
"""
717-
Check that the icdf and logcdf functions of the distribution are
718-
consistent for a sample of values in the domain of the
719-
distribution.
680+
Only works with continuous distributions.
720681
"""
721-
722-
def ftrunc(values, decimal=0):
723-
return np.trunc(values * 10**decimal) / (10**decimal)
724-
725682
if decimal is None:
726683
decimal = select_by_precision(float64=6, float32=3)
727684

728685
dist = create_dist_from_paramdomains(distribution, paramdomains)
729-
730-
value = pt.TensorType(dtype="float64", shape=[])("value")
731-
686+
if dist.type.dtype.startswith("int"):
687+
raise NotImplementedError(
688+
"check_selfconsistency_icdf is not robust against discrete distributions."
689+
)
690+
value = dist.astype("float64").type("value")
732691
dist_icdf = icdf(dist, value)
733-
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)
692+
dist_cdf = pt.exp(logcdf(dist, value))
734693

735-
dist_logcdf = logcdf(dist, value)
736-
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
694+
py_mode = Mode("py")
695+
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf, mode=py_mode)
696+
dist_cdf_fn = compile(list(inputvars(dist_cdf)), dist_cdf, mode=py_mode)
737697

738698
domains = paramdomains.copy()
739-
domains["value"] = domain
699+
domains["value"] = Domain(np.linspace(0, 1, 10))
740700

741701
for point in product(domains, n_samples=n_samples):
742702
point = dict(point)
743703
value = point.pop("value")
744-
745-
with pytensor.config.change_flags(mode=Mode("py")):
746-
expected_value = value
747-
computed_value = dist_icdf_fn(
748-
**point, value=ftrunc(np.exp(dist_logcdf_fn(**point, value=value)), decimal=decimal)
749-
)
750-
assert (
751-
expected_value == computed_value
752-
), f"expected_value = {expected_value}, computed_value = {computed_value}, {point}"
704+
icdf_value = dist_icdf_fn(**point, value=value)
705+
recovered_value = dist_cdf_fn(
706+
**point,
707+
value=icdf_value,
708+
)
709+
np.testing.assert_almost_equal(
710+
value,
711+
recovered_value,
712+
decimal=decimal,
713+
err_msg=f"point: {point}",
714+
)
753715

754716

755717
def assert_support_point_is_expected(model, expected, check_finite_logp=True):

tests/distributions/test_continuous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
check_icdf,
4646
check_logcdf,
4747
check_logp,
48-
check_selfconsistency_continuous_icdf,
48+
check_selfconsistency_icdf,
4949
continuous_random_tester,
5050
seeded_numpy_distribution_builder,
5151
seeded_scipy_distribution_builder,
@@ -442,7 +442,7 @@ def scipy_log_cdf(value, a, b):
442442
{"a": Rplus, "b": Rplus},
443443
scipy_log_cdf,
444444
)
445-
check_selfconsistency_continuous_icdf(
445+
check_selfconsistency_icdf(
446446
pm.Kumaraswamy,
447447
{"a": Rplusbig, "b": Rplusbig},
448448
)

tests/distributions/test_discrete.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
check_icdf,
5454
check_logcdf,
5555
check_logp,
56-
check_selfconsistency_discrete_icdf,
5756
check_selfconsistency_discrete_logcdf,
5857
seeded_numpy_distribution_builder,
5958
seeded_scipy_distribution_builder,
@@ -122,11 +121,6 @@ def test_discrete_unif(self):
122121
lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
123122
skip_paramdomain_outside_edge_test=True,
124123
)
125-
check_selfconsistency_discrete_icdf(
126-
pm.DiscreteUniform,
127-
Rdunif,
128-
{"lower": -Rplusdunif, "upper": Rplusdunif},
129-
)
130124
# Custom logp / logcdf check for invalid parameters
131125
invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
132126
with pytensor.config.change_flags(mode=Mode("py")):
@@ -160,11 +154,6 @@ def test_geometric(self):
160154
{"p": Unit},
161155
st.geom.ppf,
162156
)
163-
check_selfconsistency_discrete_icdf(
164-
pm.Geometric,
165-
Nat,
166-
{"p": Unit},
167-
)
168157

169158
def test_hypergeometric(self):
170159
def modified_scipy_hypergeom_logcdf(value, N, k, n):

0 commit comments

Comments
 (0)