Skip to content

Commit adcdf90

Browse files
committed
Make tests pass
1 parent a044e76 commit adcdf90

File tree

5 files changed

+27
-28
lines changed

5 files changed

+27
-28
lines changed

pymc/sampling/mcmc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,7 +1648,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray:
16481648
compile_kwargs=compile_kwargs,
16491649
)
16501650
approx_sample = approx.sample(
1651-
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
1651+
draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False
16521652
)
16531653
initial_points = [approx_sample[i] for i in range(chains)]
16541654
std_apoint = approx.std.eval()
@@ -1672,7 +1672,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray:
16721672
compile_kwargs=compile_kwargs,
16731673
)
16741674
approx_sample = approx.sample(
1675-
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
1675+
draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False
16761676
)
16771677
initial_points = [approx_sample[i] for i in range(chains)]
16781678
cov = approx.std.eval() ** 2
@@ -1690,7 +1690,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray:
16901690
compile_kwargs=compile_kwargs,
16911691
)
16921692
approx_sample = approx.sample(
1693-
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
1693+
draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False
16941694
)
16951695
initial_points = [approx_sample[i] for i in range(chains)]
16961696
cov = approx.std.eval() ** 2

pymc/variational/operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(self, approx, temperature=1):
142142

143143
def apply(self, f):
144144
# f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
145-
if _known_scan_ignored_inputs([self.approx.model.logp()]):
145+
if _known_scan_ignored_inputs([self.approx._model.logp()]):
146146
raise NotImplementedInference(
147147
"SVGD does not currently support Minibatch or Simulator RV"
148148
)

pymc/variational/opvi.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,7 @@ def symbolic_normalizing_constant(self):
12891289
obs.owner.inputs[1:],
12901290
constant_fold([obs.owner.inputs[0].shape], raise_not_constant=False),
12911291
)
1292-
for obs in self.model.observed_RVs
1292+
for obs in self._model.observed_RVs
12931293
if isinstance(obs.owner.op, MinibatchRandomVariable)
12941294
]
12951295
)
@@ -1315,7 +1315,7 @@ def logq_norm(self):
13151315
def _sized_symbolic_varlogp_and_datalogp(self):
13161316
"""*Dev* - computes sampled prior term from model via `pytensor.scan`."""
13171317
varlogp_s, datalogp_s = self.symbolic_sample_over_posterior(
1318-
[self.model.varlogp, self.model.datalogp]
1318+
[self._model.varlogp, self._model.datalogp]
13191319
)
13201320
return varlogp_s, datalogp_s # both shape (s,)
13211321

@@ -1352,7 +1352,7 @@ def datalogp(self):
13521352
@node_property
13531353
def _single_symbolic_varlogp_and_datalogp(self):
13541354
"""*Dev* - computes sampled prior term from model via `pytensor.scan`."""
1355-
varlogp, datalogp = self.symbolic_single_sample([self.model.varlogp, self.model.datalogp])
1355+
varlogp, datalogp = self.symbolic_single_sample([self._model.varlogp, self._model.datalogp])
13561356
return varlogp, datalogp
13571357

13581358
@node_property
@@ -1491,14 +1491,12 @@ def get_optimization_replacements(self, s, d):
14911491
return repl
14921492

14931493
@pytensor.config.change_flags(compute_test_value="off")
1494-
def sample_node(self, node, model=None, size=None, deterministic=False, more_replacements=None):
1494+
def sample_node(self, node, size=None, deterministic=False, more_replacements=None):
14951495
"""Sample given node or nodes over shared posterior.
14961496
14971497
Parameters
14981498
----------
14991499
node: PyTensor Variables (or PyTensor expressions)
1500-
model : Model (optional if in ``with`` context
1501-
Model to be used to generate samples.
15021500
size: None or scalar
15031501
number of samples
15041502
more_replacements: `dict`
@@ -1513,7 +1511,7 @@ def sample_node(self, node, model=None, size=None, deterministic=False, more_rep
15131511
"""
15141512
node_in = node
15151513

1516-
model = modelcontext(model)
1514+
model = self._model
15171515

15181516
if more_replacements:
15191517
node = graph_replace(node, more_replacements, strict=False)
@@ -1552,18 +1550,18 @@ def vars_names(vs):
15521550
return found
15531551

15541552
@node_property
1555-
def sample_dict_fn(self, model=None):
1553+
def sample_dict_fn(self):
15561554
s = pt.iscalar()
15571555

1558-
model = modelcontext(model)
1556+
def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None):
1557+
model = modelcontext(model)
15591558

1560-
names = [model.rvs_to_values[v].name for v in model.free_RVs]
1561-
sampled = [self.rslice(name, model) for name in names]
1562-
sampled = self.set_size_and_deterministic(sampled, s, 0)
1563-
sample_fn = compile([s], sampled)
1564-
rng_nodes = find_rng_nodes(sampled)
1559+
names = [model.rvs_to_values[v].name for v in model.free_RVs]
1560+
sampled = [self.rslice(name, model) for name in names]
1561+
sampled = self.set_size_and_deterministic(sampled, s, 0)
1562+
sample_fn = compile([s], sampled)
1563+
rng_nodes = find_rng_nodes(sampled)
15651564

1566-
def inner(draws=100, *, random_seed: SeedSequenceSeed = None):
15671565
if random_seed is not None:
15681566
reseed_rngs(rng_nodes, random_seed)
15691567
_samples = sample_fn(draws)

tests/variational/test_inference.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_fit_with_nans(score):
4141
mean = inp * coef
4242
pm.Normal("y", mean, 0.1, observed=y)
4343
with pytest.raises(FloatingPointError) as e:
44-
advi = pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan")))
44+
pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan")))
4545

4646

4747
@pytest.fixture(scope="module", params=[True, False], ids=["mini", "full"])
@@ -174,8 +174,8 @@ def fit_kwargs(inference, use_minibatch):
174174
return _select[(type(inference), key)]
175175

176176

177-
def test_fit_oo(inference, fit_kwargs, simple_model_data):
178-
trace = inference.fit(**fit_kwargs).sample(10000)
177+
def test_fit_oo(simple_model, inference, fit_kwargs, simple_model_data):
178+
trace = inference.fit(**fit_kwargs).sample(10000, model=simple_model)
179179
mu_post = simple_model_data["mu_post"]
180180
d = simple_model_data["d"]
181181
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05)
@@ -202,7 +202,8 @@ def test_fit_start(inference_spec, simple_model):
202202
inference = inference_spec(**kw)
203203

204204
try:
205-
trace = inference.fit(n=0).sample(10000)
205+
with simple_model:
206+
trace = inference.fit(n=0).sample(10000)
206207
except NotImplementedInference as e:
207208
pytest.skip(str(e))
208209

@@ -269,7 +270,7 @@ def binomial_model_inference(binomial_model, inference_spec):
269270
def test_replacements(binomial_model_inference):
270271
d = pytensor.shared(1)
271272
approx = binomial_model_inference.approx
272-
p = approx.model.p
273+
p = approx._model.p
273274
p_t = p**3
274275
p_s = approx.sample_node(p_t)
275276
assert not any(
@@ -309,7 +310,7 @@ def test_sample_replacements(binomial_model_inference):
309310
i = pt.iscalar()
310311
i.tag.test_value = 1
311312
approx = binomial_model_inference.approx
312-
p = approx.model.p
313+
p = approx._model.p
313314
p_t = p**3
314315
p_s = approx.sample_node(p_t, size=100)
315316
if pytensor.config.compute_test_value != "off":

tests/variational/test_opvi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,19 +213,19 @@ def test_pickle_approx(three_var_approx):
213213

214214
dump = cloudpickle.dumps(three_var_approx)
215215
new = cloudpickle.loads(dump)
216-
assert new.sample(1)
216+
assert new.sample(1, model=new._model)
217217

218218

219219
def test_pickle_single_group(three_var_approx_single_group_mf):
220220
import cloudpickle
221221

222222
dump = cloudpickle.dumps(three_var_approx_single_group_mf)
223223
new = cloudpickle.loads(dump)
224-
assert new.sample(1)
224+
assert new.sample(1, model=new._model)
225225

226226

227227
def test_sample_simple(three_var_approx):
228-
trace = three_var_approx.sample(100, return_inferencedata=False)
228+
trace = three_var_approx.sample(100, model=three_var_approx._model, return_inferencedata=False)
229229
assert set(trace.varnames) == {"one", "one_log__", "three", "two"}
230230
assert len(trace) == 100
231231
assert trace[0]["one"].shape == (10, 2)

0 commit comments

Comments
 (0)