@@ -429,6 +429,7 @@ def gauss_band(x_vector):
429429from typing import Any , NamedTuple
430430
431431import numpy as np
432+ from scipy .stats import norm # type: ignore
432433
433434from cobaya .conventions import prior_1d_name
434435from cobaya .log import HasLogger , LoggedError
@@ -893,59 +894,42 @@ def reference(
893894 "Reference values or pdfs for some parameters were not provided. "
894895 "Sampling from the prior instead for those parameters."
895896 )
896- # As a curiosity, `r is np.nan` was returning False after `r = np.nan` if
897- # it had been passed via MPI before the test, since this creates a "new" np.nan
898- # NB: isinstance(np.nan, numbers.Real) --> True
899- where_ignore_ref = [
900- isinstance (r , numbers .Real ) and (np .isnan (r ) or ignore_fixed )
901- for r in self .ref_pdf
902- ]
903-
904- # Determine which parameters should use override-based perturbation vs sampling from the full prior
905- where_use_override = (
906- [
907- isinstance (ref_pdf , numbers .Real )
908- and not np .isnan (ref_pdf )
909- and override_std .get (param_name ) is not None
910- for param_name , ref_pdf in zip (self .params , self .ref_pdf )
911- ]
912- if override_std and ignore_fixed
913- else [False ] * len (self .ref_pdf )
914- )
915-
897+ # Let's create an updated list of pdf|num|None using ignore_fixed and override_std
898+ updated_ref_pdfs = []
899+ i_sample_from_prior = []
900+ for i , (param , ref_pdf ) in enumerate (zip (self .params , self .ref_pdf )):
901+ overriden_std = (override_std or {}).get (param )
902+ if isinstance (ref_pdf , numbers .Real ):
903+ if np .isnan (ref_pdf ):
904+ updated_ref_pdfs .append (None )
905+ i_sample_from_prior .append (i )
906+ elif ignore_fixed :
907+ if overriden_std is None :
908+ updated_ref_pdfs .append (None )
909+ i_sample_from_prior .append (i )
910+ else :
911+ updated_ref_pdfs .append (norm (loc = ref_pdf , scale = overriden_std ))
912+ else : # actual number
913+ updated_ref_pdfs .append (ref_pdf )
914+ else : # pdf is an actual pdf
915+ updated_ref_pdfs .append (ref_pdf )
916916 tries = 0
917917 warn_if_tries = read_dnumber (warn_if_tries , self .d ())
918918 ref_sample = np .empty (len (self .ref_pdf ))
919919 while tries < max_tries :
920920 tries += 1
921-
921+ # Handle parameters using their reference values or pdfs.
922+ for i , pdf in enumerate (updated_ref_pdfs ):
923+ if hasattr (pdf , "rvs" ):
924+ ref_sample [i ] = pdf .rvs (random_state = random_state )
925+ else :
926+ ref_sample [i ] = pdf
922927 # Handle parameters that need sampling from prior (not using override)
923- where_sample_prior = [
924- where_ignore_ref [i ] and not where_use_override [i ]
925- for i in range (len (self .ref_pdf ))
926- ]
927- if any (where_sample_prior ):
928+ if i_sample_from_prior :
928929 prior_sample = self .sample (
929930 ignore_external = True , random_state = random_state
930931 )[0 ]
931- ref_sample [where_sample_prior ] = prior_sample [where_sample_prior ]
932-
933- # Handle parameters using their reference values or pdfs
934- for i , ref_pdf in enumerate (self .ref_pdf ):
935- if not where_ignore_ref [i ]:
936- if hasattr (ref_pdf , "rvs" ):
937- ref_sample [i ] = ref_pdf .rvs (random_state = random_state ) # type: ignore
938- else :
939- ref_sample [i ] = ref_pdf .real
940- elif where_use_override [i ]:
941- # Use override-based perturbation from fixed reference
942- param_name = self .params [i ]
943- ref_value = self .ref_pdf [i ].real
944- std = override_std [param_name ]
945- ref_sample [i ] = (random_state or np .random .default_rng ()).normal (
946- ref_value , std
947- )
948-
932+ ref_sample [i_sample_from_prior ] = prior_sample [i_sample_from_prior ]
949933 if self .logp (ref_sample ) > - np .inf :
950934 return ref_sample
951935 if tries == warn_if_tries :
0 commit comments