Skip to content

Commit 740cb46

Browse files
authored
prior:reference: rewritten handling of override_std (#460)
1 parent 9efcfb7 commit 740cb46

File tree

1 file changed

+28
-44
lines changed

1 file changed

+28
-44
lines changed

cobaya/prior.py

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def gauss_band(x_vector):
429429
from typing import Any, NamedTuple
430430

431431
import numpy as np
432+
from scipy.stats import norm # type: ignore
432433

433434
from cobaya.conventions import prior_1d_name
434435
from cobaya.log import HasLogger, LoggedError
@@ -893,59 +894,42 @@ def reference(
893894
"Reference values or pdfs for some parameters were not provided. "
894895
"Sampling from the prior instead for those parameters."
895896
)
896-
# As a curiosity, `r is np.nan` was returning False after `r = np.nan` if
897-
# it had been passed via MPI before the test, since this creates a "new" np.nan
898-
# NB: isinstance(np.nan, numbers.Real) --> True
899-
where_ignore_ref = [
900-
isinstance(r, numbers.Real) and (np.isnan(r) or ignore_fixed)
901-
for r in self.ref_pdf
902-
]
903-
904-
# Determine which parameters should use override-based perturbation vs sampling from the full prior
905-
where_use_override = (
906-
[
907-
isinstance(ref_pdf, numbers.Real)
908-
and not np.isnan(ref_pdf)
909-
and override_std.get(param_name) is not None
910-
for param_name, ref_pdf in zip(self.params, self.ref_pdf)
911-
]
912-
if override_std and ignore_fixed
913-
else [False] * len(self.ref_pdf)
914-
)
915-
897+
# Let's create an updated list of pdf|num|None using ignore_fixed and override_std
898+
updated_ref_pdfs = []
899+
i_sample_from_prior = []
900+
for i, (param, ref_pdf) in enumerate(zip(self.params, self.ref_pdf)):
901+
overriden_std = (override_std or {}).get(param)
902+
if isinstance(ref_pdf, numbers.Real):
903+
if np.isnan(ref_pdf):
904+
updated_ref_pdfs.append(None)
905+
i_sample_from_prior.append(i)
906+
elif ignore_fixed:
907+
if overriden_std is None:
908+
updated_ref_pdfs.append(None)
909+
i_sample_from_prior.append(i)
910+
else:
911+
updated_ref_pdfs.append(norm(loc=ref_pdf, scale=overriden_std))
912+
else: # actual number
913+
updated_ref_pdfs.append(ref_pdf)
914+
else: # pdf is an actual pdf
915+
updated_ref_pdfs.append(ref_pdf)
916916
tries = 0
917917
warn_if_tries = read_dnumber(warn_if_tries, self.d())
918918
ref_sample = np.empty(len(self.ref_pdf))
919919
while tries < max_tries:
920920
tries += 1
921-
921+
# Handle parameters using their reference values or pdfs.
922+
for i, pdf in enumerate(updated_ref_pdfs):
923+
if hasattr(pdf, "rvs"):
924+
ref_sample[i] = pdf.rvs(random_state=random_state)
925+
else:
926+
ref_sample[i] = pdf
922927
# Handle parameters that need sampling from prior (not using override)
923-
where_sample_prior = [
924-
where_ignore_ref[i] and not where_use_override[i]
925-
for i in range(len(self.ref_pdf))
926-
]
927-
if any(where_sample_prior):
928+
if i_sample_from_prior:
928929
prior_sample = self.sample(
929930
ignore_external=True, random_state=random_state
930931
)[0]
931-
ref_sample[where_sample_prior] = prior_sample[where_sample_prior]
932-
933-
# Handle parameters using their reference values or pdfs
934-
for i, ref_pdf in enumerate(self.ref_pdf):
935-
if not where_ignore_ref[i]:
936-
if hasattr(ref_pdf, "rvs"):
937-
ref_sample[i] = ref_pdf.rvs(random_state=random_state) # type: ignore
938-
else:
939-
ref_sample[i] = ref_pdf.real
940-
elif where_use_override[i]:
941-
# Use override-based perturbation from fixed reference
942-
param_name = self.params[i]
943-
ref_value = self.ref_pdf[i].real
944-
std = override_std[param_name]
945-
ref_sample[i] = (random_state or np.random.default_rng()).normal(
946-
ref_value, std
947-
)
948-
932+
ref_sample[i_sample_from_prior] = prior_sample[i_sample_from_prior]
949933
if self.logp(ref_sample) > -np.inf:
950934
return ref_sample
951935
if tries == warn_if_tries:

0 commit comments

Comments
 (0)