@@ -1428,26 +1428,27 @@ def __init__(self, groups, model=None):
14281428 self .groups = []
14291429 seen = set ()
14301430 rest = None
1431- for g in groups :
1432- if g .group is None :
1433- if rest is not None :
1434- raise GroupError ("More than one group is specified for the rest variables" )
1431+ with model :
1432+ for g in groups :
1433+ if g .group is None :
1434+ if rest is not None :
1435+ raise GroupError ("More than one group is specified for the rest variables" )
1436+ else :
1437+ rest = g
14351438 else :
1436- rest = g
1437- else :
1438- final_group = _refresh_group_for_model (g , model )
1439- if set (final_group ) & seen :
1440- raise GroupError ("Found duplicates in groups" )
1441- seen .update (final_group )
1442- self .groups .append (g )
1443- # List iteration to preserve order for reproducibility between runs
1444- unseen_free_RVs = [var for var in model .free_RVs if var not in seen ]
1445- if unseen_free_RVs :
1446- if rest is None :
1447- raise GroupError ("No approximation is specified for the rest variables" )
1448- else :
1449- _refresh_group_for_model (rest , model , unseen_free_RVs )
1450- self .groups .append (rest )
1439+ final_group = _refresh_group_for_model (g , model )
1440+ if set (final_group ) & seen :
1441+ raise GroupError ("Found duplicates in groups" )
1442+ seen .update (final_group )
1443+ self .groups .append (g )
1444+ # List iteration to preserve order for reproducibility between runs
1445+ unseen_free_RVs = [var for var in model .free_RVs if var not in seen ]
1446+ if unseen_free_RVs :
1447+ if rest is None :
1448+ raise GroupError ("No approximation is specified for the rest variables" )
1449+ else :
1450+ rest .__init_group__ (unseen_free_RVs )
1451+ self .groups .append (rest )
14511452
14521453 @property
14531454 def has_logq (self ):
@@ -1467,11 +1468,13 @@ def _ensure_groups_ready(self, model=None):
14671468 model = modelcontext (model )
14681469 except TypeError :
14691470 return
1470- for g in self .groups :
1471- _refresh_group_for_model (g , model )
1471+ with model :
1472+ for g in self .groups :
1473+ _refresh_group_for_model (g , model )
14721474
14731475 def collect (self , item ):
1474- self ._ensure_groups_ready ()
1476+ model = modelcontext (None )
1477+ self ._ensure_groups_ready (model = model )
14751478 return [getattr (g , item ) for g in self .groups ]
14761479
14771480 def _variational_orderings (self , model ):
@@ -1483,48 +1486,51 @@ def _variational_orderings(self, model):
14831486 return orderings
14841487
14851488 def _draw_variational_samples (self , model , names , draws , size_sym , random_seed ):
1486- if not names :
1487- return {}
1488- tensors = [self .rslice (name , model ) for name in names ]
1489- tensors = self .set_size_and_deterministic (tensors , size_sym , 0 )
1490- sample_fn = compile ([size_sym ], tensors )
1491- rng_nodes = find_rng_nodes (tensors )
1492- if random_seed is not None :
1493- reseed_rngs (rng_nodes , random_seed )
1494- outputs = sample_fn (draws )
1495- if not isinstance (outputs , list | tuple ):
1496- outputs = [outputs ]
1497- return dict (zip (names , outputs ))
1489+ with model :
1490+ if not names :
1491+ return {}
1492+ tensors = [self .rslice (name , model ) for name in names ]
1493+ tensors = self .set_size_and_deterministic (tensors , size_sym , 0 )
1494+ sample_fn = compile ([size_sym ], tensors )
1495+ rng_nodes = find_rng_nodes (tensors )
1496+ if random_seed is not None :
1497+ reseed_rngs (rng_nodes , random_seed )
1498+ outputs = sample_fn (draws )
1499+ if not isinstance (outputs , list | tuple ):
1500+ outputs = [outputs ]
1501+ return dict (zip (names , outputs ))
14981502
14991503 def _draw_forward_samples (self , model , approx_samples , approx_names , draws , random_seed ):
15001504 from pymc .sampling .forward import compile_forward_sampling_function
15011505
1502- model_names = {model .rvs_to_values [v ].name : v for v in model .free_RVs }
1503- forward_names = sorted (name for name in model_names if name not in approx_names )
1504- if not forward_names :
1505- return {}
1506-
1507- forward_vars = [model_names [name ] for name in forward_names ]
1508- approx_vars = [model_names [name ] for name in approx_names if name in model_names ]
1509- sampler_fn , _ = compile_forward_sampling_function (
1510- outputs = forward_vars ,
1511- vars_in_trace = approx_vars ,
1512- basic_rvs = model .basic_RVs ,
1513- givens_dict = None ,
1514- random_seed = random_seed ,
1515- )
1516- approx_value_vars = [model .rvs_to_values [var ] for var in approx_vars ]
1517- stacked = {name : [] for name in forward_names }
1518- for i in range (draws ):
1519- inputs = {
1520- value_var .name : approx_samples [value_var .name ][i ] for value_var in approx_value_vars
1521- }
1522- raw = sampler_fn (** inputs )
1523- if not isinstance (raw , list | tuple ):
1524- raw = [raw ]
1525- for name , value in zip (forward_names , raw ):
1526- stacked [name ].append (value )
1527- return {name : np .stack (values ) for name , values in stacked .items ()}
1506+ with model :
1507+ model_names = {model .rvs_to_values [v ].name : v for v in model .free_RVs }
1508+ forward_names = sorted (name for name in model_names if name not in approx_names )
1509+ if not forward_names :
1510+ return {}
1511+
1512+ forward_vars = [model_names [name ] for name in forward_names ]
1513+ approx_vars = [model_names [name ] for name in approx_names if name in model_names ]
1514+ sampler_fn , _ = compile_forward_sampling_function (
1515+ outputs = forward_vars ,
1516+ vars_in_trace = approx_vars ,
1517+ basic_rvs = model .basic_RVs ,
1518+ givens_dict = None ,
1519+ random_seed = random_seed ,
1520+ )
1521+ approx_value_vars = [model .rvs_to_values [var ] for var in approx_vars ]
1522+ stacked = {name : [] for name in forward_names }
1523+ for i in range (draws ):
1524+ inputs = {
1525+ value_var .name : approx_samples [value_var .name ][i ]
1526+ for value_var in approx_value_vars
1527+ }
1528+ raw = sampler_fn (** inputs )
1529+ if not isinstance (raw , list | tuple ):
1530+ raw = [raw ]
1531+ for name , value in zip (forward_names , raw ):
1532+ stacked [name ].append (value )
1533+ return {name : np .stack (values ) for name , values in stacked .items ()}
15281534
15291535 def _collect_sample_vars (self , model , sample_names ):
15301536 lookup = {}
@@ -1540,28 +1546,29 @@ def _collect_sample_vars(self, model, sample_names):
15401546 return sample_vars , lookup
15411547
15421548 def _compute_missing_trace_values (self , model , samples , missing_vars ):
1543- if not missing_vars :
1544- return {}
1545- input_vars = model .value_vars
1546- base_point = model .initial_point ()
1547- point = {
1548- var .name : np .asarray (samples [var .name ][0 ])
1549- if var .name in samples
1550- else base_point [var .name ]
1551- for var in input_vars
1552- if var .name in samples or var .name in base_point
1553- }
1554- compute_fn = model .compile_fn (
1555- missing_vars ,
1556- inputs = input_vars ,
1557- on_unused_input = "ignore" ,
1558- point_fn = True ,
1559- )
1560- raw_values = compute_fn (point )
1561- if not isinstance (raw_values , list | tuple ):
1562- raw_values = [raw_values ]
1563- values = {var .name : np .asarray (value ) for var , value in zip (missing_vars , raw_values )}
1564- return values
1549+ with model :
1550+ if not missing_vars :
1551+ return {}
1552+ input_vars = model .value_vars
1553+ base_point = model .initial_point ()
1554+ point = {
1555+ var .name : np .asarray (samples [var .name ][0 ])
1556+ if var .name in samples
1557+ else base_point [var .name ]
1558+ for var in input_vars
1559+ if var .name in samples or var .name in base_point
1560+ }
1561+ compute_fn = model .compile_fn (
1562+ missing_vars ,
1563+ inputs = input_vars ,
1564+ on_unused_input = "ignore" ,
1565+ point_fn = True ,
1566+ )
1567+ raw_values = compute_fn (point )
1568+ if not isinstance (raw_values , list | tuple ):
1569+ raw_values = [raw_values ]
1570+ values = {var .name : np .asarray (value ) for var , value in zip (missing_vars , raw_values )}
1571+ return values
15651572
15661573 def _build_trace_spec (self , model , samples ):
15671574 sample_names = sorted (samples .keys ())
@@ -1819,7 +1826,7 @@ def get_optimization_replacements(self, s, d):
18191826 return repl
18201827
18211828 @pytensor .config .change_flags (compute_test_value = "off" )
1822- def sample_node (self , node , size = None , deterministic = False , more_replacements = None ):
1829+ def sample_node (self , node , size = None , deterministic = False , more_replacements = None , model = None ):
18231830 """Sample given node or nodes over shared posterior.
18241831
18251832 Parameters
@@ -1839,22 +1846,22 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
18391846 """
18401847 node_in = node
18411848
1842- model = modelcontext (None )
1843-
1844- if more_replacements :
1845- node = graph_replace (node , more_replacements , strict = False )
1846- if not isinstance (node , list | tuple ):
1847- node = [node ]
1848- node = model .replace_rvs_by_values (node )
1849- if not isinstance (node_in , list | tuple ):
1850- node = node [0 ]
1851- if size is None :
1852- node_out = self .symbolic_single_sample (node )
1853- else :
1854- node_out = self .symbolic_sample_over_posterior (node )
1855- node_out = self .set_size_and_deterministic (node_out , size , deterministic )
1856- try_to_set_test_value (node_in , node_out , size )
1857- return node_out
1849+ model = modelcontext (model )
1850+ with model :
1851+ if more_replacements :
1852+ node = graph_replace (node , more_replacements , strict = False )
1853+ if not isinstance (node , list | tuple ):
1854+ node = [node ]
1855+ node = model .replace_rvs_by_values (node )
1856+ if not isinstance (node_in , list | tuple ):
1857+ node = node [0 ]
1858+ if size is None :
1859+ node_out = self .symbolic_single_sample (node )
1860+ else :
1861+ node_out = self .symbolic_sample_over_posterior (node )
1862+ node_out = self .set_size_and_deterministic (node_out , size , deterministic )
1863+ try_to_set_test_value (node_in , node_out , size )
1864+ return node_out
18581865
18591866 def rslice (self , name , model = None ):
18601867 """*Dev* - vectorized sampling for named random variable without call to `pytensor.scan`.
@@ -1863,14 +1870,15 @@ def rslice(self, name, model=None):
18631870 """
18641871 model = modelcontext (model )
18651872
1866- for random , ordering in zip (self .symbolic_randoms , self .collect ("ordering" )):
1867- if name in ordering :
1868- _name , slc , shape , dtype = ordering [name ]
1869- found = random [..., slc ].reshape ((random .shape [0 ], * shape )).astype (dtype )
1870- found .name = name + "_vi_random_slice"
1871- break
1872- else :
1873- raise KeyError (f"{ name !r} not found" )
1873+ with model :
1874+ for random , ordering in zip (self .symbolic_randoms , self .collect ("ordering" )):
1875+ if name in ordering :
1876+ _name , slc , shape , dtype = ordering [name ]
1877+ found = random [..., slc ].reshape ((random .shape [0 ], * shape )).astype (dtype )
1878+ found .name = name + "_vi_random_slice"
1879+ break
1880+ else :
1881+ raise KeyError (f"{ name !r} not found" )
18741882 return found
18751883
18761884 @node_property
@@ -1879,15 +1887,16 @@ def sample_dict_fn(self):
18791887
18801888 def inner (draws = 100 , * , model = None , random_seed : SeedSequenceSeed = None ):
18811889 model = modelcontext (model )
1882- orderings = self ._variational_orderings (model )
1883- approx_var_names = sorted (orderings .keys ())
1884- approx_samples = self ._draw_variational_samples (
1885- model , approx_var_names , draws , s , random_seed
1886- )
1887- forward_samples = self ._draw_forward_samples (
1888- model , approx_samples , approx_var_names , draws , random_seed
1889- )
1890- return {** approx_samples , ** forward_samples }
1890+ with model :
1891+ orderings = self ._variational_orderings (model )
1892+ approx_var_names = sorted (orderings .keys ())
1893+ approx_samples = self ._draw_variational_samples (
1894+ model , approx_var_names , draws , s , random_seed
1895+ )
1896+ forward_samples = self ._draw_forward_samples (
1897+ model , approx_samples , approx_var_names , draws , random_seed
1898+ )
1899+ return {** approx_samples , ** forward_samples }
18911900
18921901 return inner
18931902
@@ -1922,37 +1931,38 @@ def sample(
19221931
19231932 model = modelcontext (model )
19241933
1925- if random_seed is not None :
1926- (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
1927- samples : dict = self .sample_dict_fn (draws , model = model , random_seed = random_seed )
1928- spec = self ._build_trace_spec (model , samples )
1929-
1930- from collections import OrderedDict
1931-
1932- default_point = model .initial_point ()
1933- value_var_names = [var .name for var in model .value_vars ]
1934- points = (
1935- OrderedDict (
1936- (
1937- name ,
1938- np .asarray (samples [name ][i ])
1939- if name in samples and len (samples [name ]) > i
1940- else np .asarray (spec .test_point .get (name , default_point [name ])),
1934+ with model :
1935+ if random_seed is not None :
1936+ (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
1937+ samples : dict = self .sample_dict_fn (draws , model = model , random_seed = random_seed )
1938+ spec = self ._build_trace_spec (model , samples )
1939+
1940+ from collections import OrderedDict
1941+
1942+ default_point = model .initial_point ()
1943+ value_var_names = [var .name for var in model .value_vars ]
1944+ points = (
1945+ OrderedDict (
1946+ (
1947+ name ,
1948+ np .asarray (samples [name ][i ])
1949+ if name in samples and len (samples [name ]) > i
1950+ else np .asarray (spec .test_point .get (name , default_point [name ])),
1951+ )
1952+ for name in value_var_names
19411953 )
1942- for name in value_var_names
1954+ for i in range ( draws )
19431955 )
1944- for i in range (draws )
1945- )
19461956
1947- trace = NDArray (
1948- model = model ,
1949- )
1950- try :
1951- trace .setup (draws = draws , chain = 0 )
1952- for point in points :
1953- trace .record (point )
1954- finally :
1955- trace .close ()
1957+ trace = NDArray (
1958+ model = model ,
1959+ )
1960+ try :
1961+ trace .setup (draws = draws , chain = 0 )
1962+ for point in points :
1963+ trace .record (point )
1964+ finally :
1965+ trace .close ()
19561966
19571967 multi_trace = MultiTrace ([trace ])
19581968 if not return_inferencedata :
0 commit comments