Skip to content

Commit 5c15eb5

Browse files
committed
Use model context everywhere
1 parent 2d05203 commit 5c15eb5

File tree

3 files changed

+161
-151
lines changed

3 files changed

+161
-151
lines changed

pymc/variational/opvi.py

Lines changed: 154 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,26 +1428,27 @@ def __init__(self, groups, model=None):
14281428
self.groups = []
14291429
seen = set()
14301430
rest = None
1431-
for g in groups:
1432-
if g.group is None:
1433-
if rest is not None:
1434-
raise GroupError("More than one group is specified for the rest variables")
1431+
with model:
1432+
for g in groups:
1433+
if g.group is None:
1434+
if rest is not None:
1435+
raise GroupError("More than one group is specified for the rest variables")
1436+
else:
1437+
rest = g
14351438
else:
1436-
rest = g
1437-
else:
1438-
final_group = _refresh_group_for_model(g, model)
1439-
if set(final_group) & seen:
1440-
raise GroupError("Found duplicates in groups")
1441-
seen.update(final_group)
1442-
self.groups.append(g)
1443-
# List iteration to preserve order for reproducibility between runs
1444-
unseen_free_RVs = [var for var in model.free_RVs if var not in seen]
1445-
if unseen_free_RVs:
1446-
if rest is None:
1447-
raise GroupError("No approximation is specified for the rest variables")
1448-
else:
1449-
_refresh_group_for_model(rest, model, unseen_free_RVs)
1450-
self.groups.append(rest)
1439+
final_group = _refresh_group_for_model(g, model)
1440+
if set(final_group) & seen:
1441+
raise GroupError("Found duplicates in groups")
1442+
seen.update(final_group)
1443+
self.groups.append(g)
1444+
# List iteration to preserve order for reproducibility between runs
1445+
unseen_free_RVs = [var for var in model.free_RVs if var not in seen]
1446+
if unseen_free_RVs:
1447+
if rest is None:
1448+
raise GroupError("No approximation is specified for the rest variables")
1449+
else:
1450+
rest.__init_group__(unseen_free_RVs)
1451+
self.groups.append(rest)
14511452

14521453
@property
14531454
def has_logq(self):
@@ -1467,11 +1468,13 @@ def _ensure_groups_ready(self, model=None):
14671468
model = modelcontext(model)
14681469
except TypeError:
14691470
return
1470-
for g in self.groups:
1471-
_refresh_group_for_model(g, model)
1471+
with model:
1472+
for g in self.groups:
1473+
_refresh_group_for_model(g, model)
14721474

14731475
def collect(self, item):
1474-
self._ensure_groups_ready()
1476+
model = modelcontext(None)
1477+
self._ensure_groups_ready(model=model)
14751478
return [getattr(g, item) for g in self.groups]
14761479

14771480
def _variational_orderings(self, model):
@@ -1483,48 +1486,51 @@ def _variational_orderings(self, model):
14831486
return orderings
14841487

14851488
def _draw_variational_samples(self, model, names, draws, size_sym, random_seed):
1486-
if not names:
1487-
return {}
1488-
tensors = [self.rslice(name, model) for name in names]
1489-
tensors = self.set_size_and_deterministic(tensors, size_sym, 0)
1490-
sample_fn = compile([size_sym], tensors)
1491-
rng_nodes = find_rng_nodes(tensors)
1492-
if random_seed is not None:
1493-
reseed_rngs(rng_nodes, random_seed)
1494-
outputs = sample_fn(draws)
1495-
if not isinstance(outputs, list | tuple):
1496-
outputs = [outputs]
1497-
return dict(zip(names, outputs))
1489+
with model:
1490+
if not names:
1491+
return {}
1492+
tensors = [self.rslice(name, model) for name in names]
1493+
tensors = self.set_size_and_deterministic(tensors, size_sym, 0)
1494+
sample_fn = compile([size_sym], tensors)
1495+
rng_nodes = find_rng_nodes(tensors)
1496+
if random_seed is not None:
1497+
reseed_rngs(rng_nodes, random_seed)
1498+
outputs = sample_fn(draws)
1499+
if not isinstance(outputs, list | tuple):
1500+
outputs = [outputs]
1501+
return dict(zip(names, outputs))
14981502

14991503
def _draw_forward_samples(self, model, approx_samples, approx_names, draws, random_seed):
15001504
from pymc.sampling.forward import compile_forward_sampling_function
15011505

1502-
model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs}
1503-
forward_names = sorted(name for name in model_names if name not in approx_names)
1504-
if not forward_names:
1505-
return {}
1506-
1507-
forward_vars = [model_names[name] for name in forward_names]
1508-
approx_vars = [model_names[name] for name in approx_names if name in model_names]
1509-
sampler_fn, _ = compile_forward_sampling_function(
1510-
outputs=forward_vars,
1511-
vars_in_trace=approx_vars,
1512-
basic_rvs=model.basic_RVs,
1513-
givens_dict=None,
1514-
random_seed=random_seed,
1515-
)
1516-
approx_value_vars = [model.rvs_to_values[var] for var in approx_vars]
1517-
stacked = {name: [] for name in forward_names}
1518-
for i in range(draws):
1519-
inputs = {
1520-
value_var.name: approx_samples[value_var.name][i] for value_var in approx_value_vars
1521-
}
1522-
raw = sampler_fn(**inputs)
1523-
if not isinstance(raw, list | tuple):
1524-
raw = [raw]
1525-
for name, value in zip(forward_names, raw):
1526-
stacked[name].append(value)
1527-
return {name: np.stack(values) for name, values in stacked.items()}
1506+
with model:
1507+
model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs}
1508+
forward_names = sorted(name for name in model_names if name not in approx_names)
1509+
if not forward_names:
1510+
return {}
1511+
1512+
forward_vars = [model_names[name] for name in forward_names]
1513+
approx_vars = [model_names[name] for name in approx_names if name in model_names]
1514+
sampler_fn, _ = compile_forward_sampling_function(
1515+
outputs=forward_vars,
1516+
vars_in_trace=approx_vars,
1517+
basic_rvs=model.basic_RVs,
1518+
givens_dict=None,
1519+
random_seed=random_seed,
1520+
)
1521+
approx_value_vars = [model.rvs_to_values[var] for var in approx_vars]
1522+
stacked = {name: [] for name in forward_names}
1523+
for i in range(draws):
1524+
inputs = {
1525+
value_var.name: approx_samples[value_var.name][i]
1526+
for value_var in approx_value_vars
1527+
}
1528+
raw = sampler_fn(**inputs)
1529+
if not isinstance(raw, list | tuple):
1530+
raw = [raw]
1531+
for name, value in zip(forward_names, raw):
1532+
stacked[name].append(value)
1533+
return {name: np.stack(values) for name, values in stacked.items()}
15281534

15291535
def _collect_sample_vars(self, model, sample_names):
15301536
lookup = {}
@@ -1540,28 +1546,29 @@ def _collect_sample_vars(self, model, sample_names):
15401546
return sample_vars, lookup
15411547

15421548
def _compute_missing_trace_values(self, model, samples, missing_vars):
1543-
if not missing_vars:
1544-
return {}
1545-
input_vars = model.value_vars
1546-
base_point = model.initial_point()
1547-
point = {
1548-
var.name: np.asarray(samples[var.name][0])
1549-
if var.name in samples
1550-
else base_point[var.name]
1551-
for var in input_vars
1552-
if var.name in samples or var.name in base_point
1553-
}
1554-
compute_fn = model.compile_fn(
1555-
missing_vars,
1556-
inputs=input_vars,
1557-
on_unused_input="ignore",
1558-
point_fn=True,
1559-
)
1560-
raw_values = compute_fn(point)
1561-
if not isinstance(raw_values, list | tuple):
1562-
raw_values = [raw_values]
1563-
values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)}
1564-
return values
1549+
with model:
1550+
if not missing_vars:
1551+
return {}
1552+
input_vars = model.value_vars
1553+
base_point = model.initial_point()
1554+
point = {
1555+
var.name: np.asarray(samples[var.name][0])
1556+
if var.name in samples
1557+
else base_point[var.name]
1558+
for var in input_vars
1559+
if var.name in samples or var.name in base_point
1560+
}
1561+
compute_fn = model.compile_fn(
1562+
missing_vars,
1563+
inputs=input_vars,
1564+
on_unused_input="ignore",
1565+
point_fn=True,
1566+
)
1567+
raw_values = compute_fn(point)
1568+
if not isinstance(raw_values, list | tuple):
1569+
raw_values = [raw_values]
1570+
values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)}
1571+
return values
15651572

15661573
def _build_trace_spec(self, model, samples):
15671574
sample_names = sorted(samples.keys())
@@ -1819,7 +1826,7 @@ def get_optimization_replacements(self, s, d):
18191826
return repl
18201827

18211828
@pytensor.config.change_flags(compute_test_value="off")
1822-
def sample_node(self, node, size=None, deterministic=False, more_replacements=None):
1829+
def sample_node(self, node, size=None, deterministic=False, more_replacements=None, model=None):
18231830
"""Sample given node or nodes over shared posterior.
18241831
18251832
Parameters
@@ -1839,22 +1846,22 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
18391846
"""
18401847
node_in = node
18411848

1842-
model = modelcontext(None)
1843-
1844-
if more_replacements:
1845-
node = graph_replace(node, more_replacements, strict=False)
1846-
if not isinstance(node, list | tuple):
1847-
node = [node]
1848-
node = model.replace_rvs_by_values(node)
1849-
if not isinstance(node_in, list | tuple):
1850-
node = node[0]
1851-
if size is None:
1852-
node_out = self.symbolic_single_sample(node)
1853-
else:
1854-
node_out = self.symbolic_sample_over_posterior(node)
1855-
node_out = self.set_size_and_deterministic(node_out, size, deterministic)
1856-
try_to_set_test_value(node_in, node_out, size)
1857-
return node_out
1849+
model = modelcontext(model)
1850+
with model:
1851+
if more_replacements:
1852+
node = graph_replace(node, more_replacements, strict=False)
1853+
if not isinstance(node, list | tuple):
1854+
node = [node]
1855+
node = model.replace_rvs_by_values(node)
1856+
if not isinstance(node_in, list | tuple):
1857+
node = node[0]
1858+
if size is None:
1859+
node_out = self.symbolic_single_sample(node)
1860+
else:
1861+
node_out = self.symbolic_sample_over_posterior(node)
1862+
node_out = self.set_size_and_deterministic(node_out, size, deterministic)
1863+
try_to_set_test_value(node_in, node_out, size)
1864+
return node_out
18581865

18591866
def rslice(self, name, model=None):
18601867
"""*Dev* - vectorized sampling for named random variable without call to `pytensor.scan`.
@@ -1863,14 +1870,15 @@ def rslice(self, name, model=None):
18631870
"""
18641871
model = modelcontext(model)
18651872

1866-
for random, ordering in zip(self.symbolic_randoms, self.collect("ordering")):
1867-
if name in ordering:
1868-
_name, slc, shape, dtype = ordering[name]
1869-
found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype)
1870-
found.name = name + "_vi_random_slice"
1871-
break
1872-
else:
1873-
raise KeyError(f"{name!r} not found")
1873+
with model:
1874+
for random, ordering in zip(self.symbolic_randoms, self.collect("ordering")):
1875+
if name in ordering:
1876+
_name, slc, shape, dtype = ordering[name]
1877+
found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype)
1878+
found.name = name + "_vi_random_slice"
1879+
break
1880+
else:
1881+
raise KeyError(f"{name!r} not found")
18741882
return found
18751883

18761884
@node_property
@@ -1879,15 +1887,16 @@ def sample_dict_fn(self):
18791887

18801888
def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None):
18811889
model = modelcontext(model)
1882-
orderings = self._variational_orderings(model)
1883-
approx_var_names = sorted(orderings.keys())
1884-
approx_samples = self._draw_variational_samples(
1885-
model, approx_var_names, draws, s, random_seed
1886-
)
1887-
forward_samples = self._draw_forward_samples(
1888-
model, approx_samples, approx_var_names, draws, random_seed
1889-
)
1890-
return {**approx_samples, **forward_samples}
1890+
with model:
1891+
orderings = self._variational_orderings(model)
1892+
approx_var_names = sorted(orderings.keys())
1893+
approx_samples = self._draw_variational_samples(
1894+
model, approx_var_names, draws, s, random_seed
1895+
)
1896+
forward_samples = self._draw_forward_samples(
1897+
model, approx_samples, approx_var_names, draws, random_seed
1898+
)
1899+
return {**approx_samples, **forward_samples}
18911900

18921901
return inner
18931902

@@ -1922,37 +1931,38 @@ def sample(
19221931

19231932
model = modelcontext(model)
19241933

1925-
if random_seed is not None:
1926-
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
1927-
samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed)
1928-
spec = self._build_trace_spec(model, samples)
1929-
1930-
from collections import OrderedDict
1931-
1932-
default_point = model.initial_point()
1933-
value_var_names = [var.name for var in model.value_vars]
1934-
points = (
1935-
OrderedDict(
1936-
(
1937-
name,
1938-
np.asarray(samples[name][i])
1939-
if name in samples and len(samples[name]) > i
1940-
else np.asarray(spec.test_point.get(name, default_point[name])),
1934+
with model:
1935+
if random_seed is not None:
1936+
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
1937+
samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed)
1938+
spec = self._build_trace_spec(model, samples)
1939+
1940+
from collections import OrderedDict
1941+
1942+
default_point = model.initial_point()
1943+
value_var_names = [var.name for var in model.value_vars]
1944+
points = (
1945+
OrderedDict(
1946+
(
1947+
name,
1948+
np.asarray(samples[name][i])
1949+
if name in samples and len(samples[name]) > i
1950+
else np.asarray(spec.test_point.get(name, default_point[name])),
1951+
)
1952+
for name in value_var_names
19411953
)
1942-
for name in value_var_names
1954+
for i in range(draws)
19431955
)
1944-
for i in range(draws)
1945-
)
19461956

1947-
trace = NDArray(
1948-
model=model,
1949-
)
1950-
try:
1951-
trace.setup(draws=draws, chain=0)
1952-
for point in points:
1953-
trace.record(point)
1954-
finally:
1955-
trace.close()
1957+
trace = NDArray(
1958+
model=model,
1959+
)
1960+
try:
1961+
trace.setup(draws=draws, chain=0)
1962+
for point in points:
1963+
trace.record(point)
1964+
finally:
1965+
trace.close()
19561966

19571967
multi_trace = MultiTrace([trace])
19581968
if not return_inferencedata:

tests/variational/test_approximations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ def test_elbo():
5555

5656
# Create variational gradient tensor
5757
mean_field = MeanField(model=model)
58-
with pytensor.config.change_flags(compute_test_value="off"):
59-
elbo = -pm.operators.KL(mean_field)()(10000)
58+
with model:
59+
with pytensor.config.change_flags(compute_test_value="off"):
60+
elbo = -pm.operators.KL(mean_field)()(10000)
6061

6162
mean_field.shared_params["mu"].set_value(post_mu)
6263
mean_field.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1))
@@ -113,9 +114,8 @@ def test_scale_cost_to_minibatch_works(aux_total_size):
113114
assert not mean_field_2.scale_cost_to_minibatch
114115
mean_field_2.shared_params["mu"].set_value(post_mu)
115116
mean_field_2.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1))
116-
117-
with pytensor.config.change_flags(compute_test_value="off"):
118-
elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000)
117+
with pytensor.config.change_flags(compute_test_value="off"):
118+
elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000)
119119

120120
np.testing.assert_allclose(
121121
elbo_via_total_size_unscaled.eval(),

0 commit comments

Comments
 (0)