@@ -85,6 +85,13 @@ def create_shared_params(self, start=None, start_sigma=None):
8585 # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
8686 # future code changes that break this assumption.
8787 start = self ._prepare_start (start )
88+ # Ensure start is a 1D array and matches ddim
89+ start = np .asarray (start ).flatten ()
90+ if start .size != self .ddim :
91+ raise ValueError (
92+ f"Start array size mismatch: got { start .size } , expected { self .ddim } . "
93+ f"Start shape: { start .shape if hasattr (start , 'shape' ) else 'unknown' } "
94+ )
8895 rho1 = np .zeros ((self .ddim ,))
8996
9097 if start_sigma is not None :
@@ -139,6 +146,13 @@ def __init_group__(self, group):
139146
140147 def create_shared_params (self , start = None ):
141148 start = self ._prepare_start (start )
149+ # Ensure start is a 1D array and matches ddim
150+ start = np .asarray (start ).flatten ()
151+ if start .size != self .ddim :
152+ raise ValueError (
153+ f"Start array size mismatch: got { start .size } , expected { self .ddim } . "
154+ f"Start shape: { start .shape if hasattr (start , 'shape' ) else 'unknown' } "
155+ )
142156 n = self .ddim
143157 L_tril = np .eye (n )[np .tril_indices (n )].astype (pytensor .config .floatX )
144158 return {"mu" : pytensor .shared (start , "mu" ), "L_tril" : pytensor .shared (L_tril , "L_tril" )}
@@ -233,17 +247,19 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
233247 return {"histogram" : pytensor .shared (pm .floatX (histogram ), "histogram" )}
234248
235249 def _check_trace (self ):
250+ from pymc .model import modelcontext
251+
236252 trace = self ._kwargs .get ("trace" , None )
237253 if isinstance (trace , InferenceData ):
238254 raise NotImplementedError (
239255 "The `Empirical` approximation does not yet support `InferenceData` inputs."
240256 " Pass `pm.sample(return_inferencedata=False)` to get a `MultiTrace` to use with `Empirical`."
241257 " Please help us to refactor: https://github.com/pymc-devs/pymc/issues/5884"
242258 )
243- elif trace is not None and not all (
244- self . model . rvs_to_values [ var ]. name in trace . varnames for var in self . group
245- ):
246- raise ValueError ("trace has not all free RVs in the group" )
259+ elif trace is not None :
260+ model = modelcontext ( None )
261+ if not all ( model . rvs_to_values [ var ]. name in trace . varnames for var in self . group ):
262+ raise ValueError ("trace has not all free RVs in the group" )
247263
248264 def randidx (self , size = None ):
249265 if size is None :
0 commit comments