@@ -1289,7 +1289,7 @@ def symbolic_normalizing_constant(self):
12891289 obs .owner .inputs [1 :],
12901290 constant_fold ([obs .owner .inputs [0 ].shape ], raise_not_constant = False ),
12911291 )
1292- for obs in self .model .observed_RVs
1292+ for obs in self ._model .observed_RVs
12931293 if isinstance (obs .owner .op , MinibatchRandomVariable )
12941294 ]
12951295 )
@@ -1315,7 +1315,7 @@ def logq_norm(self):
13151315 def _sized_symbolic_varlogp_and_datalogp (self ):
13161316 """*Dev* - computes sampled prior term from model via `pytensor.scan`."""
13171317 varlogp_s , datalogp_s = self .symbolic_sample_over_posterior (
1318- [self .model .varlogp , self .model .datalogp ]
1318+ [self ._model .varlogp , self ._model .datalogp ]
13191319 )
13201320 return varlogp_s , datalogp_s # both shape (s,)
13211321
@@ -1352,7 +1352,7 @@ def datalogp(self):
13521352 @node_property
13531353 def _single_symbolic_varlogp_and_datalogp (self ):
13541354 """*Dev* - computes sampled prior term from model via `pytensor.scan`."""
1355- varlogp , datalogp = self .symbolic_single_sample ([self .model .varlogp , self .model .datalogp ])
1355+ varlogp , datalogp = self .symbolic_single_sample ([self ._model .varlogp , self ._model .datalogp ])
13561356 return varlogp , datalogp
13571357
13581358 @node_property
@@ -1491,14 +1491,12 @@ def get_optimization_replacements(self, s, d):
14911491 return repl
14921492
14931493 @pytensor .config .change_flags (compute_test_value = "off" )
1494- def sample_node (self , node , model = None , size = None , deterministic = False , more_replacements = None ):
1494+ def sample_node (self , node , size = None , deterministic = False , more_replacements = None ):
14951495 """Sample given node or nodes over shared posterior.
14961496
14971497 Parameters
14981498 ----------
14991499 node: PyTensor Variables (or PyTensor expressions)
1500- model : Model (optional if in ``with`` context
1501- Model to be used to generate samples.
15021500 size: None or scalar
15031501 number of samples
15041502 more_replacements: `dict`
@@ -1513,7 +1511,7 @@ def sample_node(self, node, model=None, size=None, deterministic=False, more_rep
15131511 """
15141512 node_in = node
15151513
1516- model = modelcontext ( model )
1514+ model = self . _model
15171515
15181516 if more_replacements :
15191517 node = graph_replace (node , more_replacements , strict = False )
@@ -1552,18 +1550,18 @@ def vars_names(vs):
15521550 return found
15531551
15541552 @node_property
1555- def sample_dict_fn (self , model = None ):
1553+ def sample_dict_fn (self ):
15561554 s = pt .iscalar ()
15571555
1558- model = modelcontext (model )
1556+ def inner (draws = 100 , * , model = None , random_seed : SeedSequenceSeed = None ):
1557+ model = modelcontext (model )
15591558
1560- names = [model .rvs_to_values [v ].name for v in model .free_RVs ]
1561- sampled = [self .rslice (name , model ) for name in names ]
1562- sampled = self .set_size_and_deterministic (sampled , s , 0 )
1563- sample_fn = compile ([s ], sampled )
1564- rng_nodes = find_rng_nodes (sampled )
1559+ names = [model .rvs_to_values [v ].name for v in model .free_RVs ]
1560+ sampled = [self .rslice (name , model ) for name in names ]
1561+ sampled = self .set_size_and_deterministic (sampled , s , 0 )
1562+ sample_fn = compile ([s ], sampled )
1563+ rng_nodes = find_rng_nodes (sampled )
15651564
1566- def inner (draws = 100 , * , random_seed : SeedSequenceSeed = None ):
15671565 if random_seed is not None :
15681566 reseed_rngs (rng_nodes , random_seed )
15691567 _samples = sample_fn (draws )
0 commit comments