Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 28 additions & 44 deletions cobaya/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def gauss_band(x_vector):
from typing import Any, NamedTuple

import numpy as np
from scipy.stats import norm # type: ignore

from cobaya.conventions import prior_1d_name
from cobaya.log import HasLogger, LoggedError
Expand Down Expand Up @@ -893,59 +894,42 @@ def reference(
"Reference values or pdfs for some parameters were not provided. "
"Sampling from the prior instead for those parameters."
)
# As a curiosity, `r is np.nan` was returning False after `r = np.nan` if
# it had been passed via MPI before the test, since this creates a "new" np.nan
# NB: isinstance(np.nan, numbers.Real) --> True
where_ignore_ref = [
isinstance(r, numbers.Real) and (np.isnan(r) or ignore_fixed)
for r in self.ref_pdf
]

# Determine which parameters should use override-based perturbation vs sampling from the full prior
where_use_override = (
[
isinstance(ref_pdf, numbers.Real)
and not np.isnan(ref_pdf)
and override_std.get(param_name) is not None
for param_name, ref_pdf in zip(self.params, self.ref_pdf)
]
if override_std and ignore_fixed
else [False] * len(self.ref_pdf)
)

# Let's create an updated list of pdf|num|None using ignore_fixed and override_std
updated_ref_pdfs = []
i_sample_from_prior = []
for i, (param, ref_pdf) in enumerate(zip(self.params, self.ref_pdf)):
overriden_std = (override_std or {}).get(param)
if isinstance(ref_pdf, numbers.Real):
if np.isnan(ref_pdf):
updated_ref_pdfs.append(None)
i_sample_from_prior.append(i)
elif ignore_fixed:
if overriden_std is None:
updated_ref_pdfs.append(None)
i_sample_from_prior.append(i)
else:
updated_ref_pdfs.append(norm(loc=ref_pdf, scale=overriden_std))
else: # actual number
updated_ref_pdfs.append(ref_pdf)
else: # pdf is an actual pdf
updated_ref_pdfs.append(ref_pdf)
tries = 0
warn_if_tries = read_dnumber(warn_if_tries, self.d())
ref_sample = np.empty(len(self.ref_pdf))
while tries < max_tries:
tries += 1

# Handle parameters using their reference values or pdfs.
for i, pdf in enumerate(updated_ref_pdfs):
if hasattr(pdf, "rvs"):
ref_sample[i] = pdf.rvs(random_state=random_state)
else:
ref_sample[i] = pdf
# Handle parameters that need sampling from prior (not using override)
where_sample_prior = [
where_ignore_ref[i] and not where_use_override[i]
for i in range(len(self.ref_pdf))
]
if any(where_sample_prior):
if i_sample_from_prior:
prior_sample = self.sample(
ignore_external=True, random_state=random_state
)[0]
ref_sample[where_sample_prior] = prior_sample[where_sample_prior]

# Handle parameters using their reference values or pdfs
for i, ref_pdf in enumerate(self.ref_pdf):
if not where_ignore_ref[i]:
if hasattr(ref_pdf, "rvs"):
ref_sample[i] = ref_pdf.rvs(random_state=random_state) # type: ignore
else:
ref_sample[i] = ref_pdf.real
elif where_use_override[i]:
# Use override-based perturbation from fixed reference
param_name = self.params[i]
ref_value = self.ref_pdf[i].real
std = override_std[param_name]
ref_sample[i] = (random_state or np.random.default_rng()).normal(
ref_value, std
)

ref_sample[i_sample_from_prior] = prior_sample[i_sample_from_prior]
if self.logp(ref_sample) > -np.inf:
return ref_sample
if tries == warn_if_tries:
Expand Down
Loading