Skip to content

Commit 2d05203

Browse files
committed
Cleanup
1 parent 41cfc20 commit 2d05203

File tree

1 file changed

+5
-39
lines changed

1 file changed

+5
-39
lines changed

pymc/variational/opvi.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,10 +1382,6 @@ def _rebuild_group_mappings(group, model):
13821382
class TraceSpec:
13831383
sample_vars: list
13841384
test_point: collections.OrderedDict
1385-
computed_var_names: list[str]
1386-
input_vars: list
1387-
compute_fn: Any
1388-
value_var_names: list[str]
13891385

13901386

13911387
class Approximation(WithMemoization):
@@ -1545,7 +1541,7 @@ def _collect_sample_vars(self, model, sample_names):
15451541

15461542
def _compute_missing_trace_values(self, model, samples, missing_vars):
15471543
if not missing_vars:
1548-
return {}, [], [], None
1544+
return {}
15491545
input_vars = model.value_vars
15501546
base_point = model.initial_point()
15511547
point = {
@@ -1565,7 +1561,7 @@ def _compute_missing_trace_values(self, model, samples, missing_vars):
15651561
if not isinstance(raw_values, list | tuple):
15661562
raw_values = [raw_values]
15671563
values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)}
1568-
return values, [var.name for var in missing_vars], list(input_vars), compute_fn
1564+
return values
15691565

15701566
def _build_trace_spec(self, model, samples):
15711567
sample_names = sorted(samples.keys())
@@ -1586,43 +1582,15 @@ def _build_trace_spec(self, model, samples):
15861582
continue
15871583
missing_vars.append(var)
15881584

1589-
values, computed_var_names, input_vars, compute_fn = self._compute_missing_trace_values(
1590-
model, samples, missing_vars
1591-
)
1585+
values = self._compute_missing_trace_values(model, samples, missing_vars)
15921586
for name, value in values.items():
15931587
test_point[name] = value
15941588

15951589
return TraceSpec(
15961590
sample_vars=sample_vars,
15971591
test_point=test_point,
1598-
computed_var_names=computed_var_names,
1599-
input_vars=input_vars,
1600-
compute_fn=compute_fn,
1601-
value_var_names=[var.name for var in model.value_vars],
16021592
)
16031593

1604-
def _augment_samples_with_computed(self, model, samples, spec, draws):
1605-
if not spec.computed_var_names:
1606-
return
1607-
1608-
computed = {name: [] for name in spec.computed_var_names}
1609-
input_names = [var.name for var in spec.input_vars]
1610-
for i in range(draws):
1611-
inputs = {}
1612-
for name in input_names:
1613-
if name in samples:
1614-
inputs[name] = samples[name][i]
1615-
else:
1616-
inputs[name] = spec.test_point[name]
1617-
outputs = spec.compute_fn(inputs)
1618-
if not isinstance(outputs, list | tuple):
1619-
outputs = [outputs]
1620-
for name, value in zip(spec.computed_var_names, outputs):
1621-
computed[name].append(np.asarray(value))
1622-
1623-
for name, values in computed.items():
1624-
samples[name] = np.stack(values)
1625-
16261594
inputs = property(lambda self: self.collect("input"))
16271595
symbolic_randoms = property(lambda self: self.collect("symbolic_random"))
16281596

@@ -1958,13 +1926,11 @@ def sample(
19581926
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
19591927
samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed)
19601928
spec = self._build_trace_spec(model, samples)
1961-
self._augment_samples_with_computed(model, samples, spec, draws)
1962-
if spec.computed_var_names:
1963-
spec = self._build_trace_spec(model, samples)
19641929

19651930
from collections import OrderedDict
19661931

19671932
default_point = model.initial_point()
1933+
value_var_names = [var.name for var in model.value_vars]
19681934
points = (
19691935
OrderedDict(
19701936
(
@@ -1973,7 +1939,7 @@ def sample(
19731939
if name in samples and len(samples[name]) > i
19741940
else np.asarray(spec.test_point.get(name, default_point[name])),
19751941
)
1976-
for name in spec.value_var_names
1942+
for name in value_var_names
19771943
)
19781944
for i in range(draws)
19791945
)

0 commit comments

Comments
 (0)