Skip to content

Commit 6a6b591

Browse files
committed
Vibe coded
1 parent adcdf90 commit 6a6b591

File tree

5 files changed

+722
-165
lines changed

5 files changed

+722
-165
lines changed

pymc/variational/approximations.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def create_shared_params(self, start=None, start_sigma=None):
8585
# by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
8686
# future code changes that break this assumption.
8787
start = self._prepare_start(start)
88+
# Ensure start is a 1D array and matches ddim
89+
start = np.asarray(start).flatten()
90+
if start.size != self.ddim:
91+
raise ValueError(
92+
f"Start array size mismatch: got {start.size}, expected {self.ddim}. "
93+
f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}"
94+
)
8895
rho1 = np.zeros((self.ddim,))
8996

9097
if start_sigma is not None:
@@ -139,6 +146,13 @@ def __init_group__(self, group):
139146

140147
def create_shared_params(self, start=None):
141148
start = self._prepare_start(start)
149+
# Ensure start is a 1D array and matches ddim
150+
start = np.asarray(start).flatten()
151+
if start.size != self.ddim:
152+
raise ValueError(
153+
f"Start array size mismatch: got {start.size}, expected {self.ddim}. "
154+
f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}"
155+
)
142156
n = self.ddim
143157
L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX)
144158
return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")}
@@ -233,17 +247,19 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
233247
return {"histogram": pytensor.shared(pm.floatX(histogram), "histogram")}
234248

235249
def _check_trace(self):
250+
from pymc.model import modelcontext
251+
236252
trace = self._kwargs.get("trace", None)
237253
if isinstance(trace, InferenceData):
238254
raise NotImplementedError(
239255
"The `Empirical` approximation does not yet support `InferenceData` inputs."
240256
" Pass `pm.sample(return_inferencedata=False)` to get a `MultiTrace` to use with `Empirical`."
241257
" Please help us to refactor: https://github.com/pymc-devs/pymc/issues/5884"
242258
)
243-
elif trace is not None and not all(
244-
self.model.rvs_to_values[var].name in trace.varnames for var in self.group
245-
):
246-
raise ValueError("trace has not all free RVs in the group")
259+
elif trace is not None:
260+
model = modelcontext(None)
261+
if not all(model.rvs_to_values[var].name in trace.varnames for var in self.group):
262+
raise ValueError("trace has not all free RVs in the group")
247263

248264
def randidx(self, size=None):
249265
if size is None:

pymc/variational/operators.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pymc as pm
2121

22+
from pymc.model import modelcontext
2223
from pymc.variational import opvi
2324
from pymc.variational.opvi import (
2425
NotImplementedInference,
@@ -142,7 +143,8 @@ def __init__(self, approx, temperature=1):
142143

143144
def apply(self, f):
144145
# f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
145-
if _known_scan_ignored_inputs([self.approx._model.logp()]):
146+
model = modelcontext(None)
147+
if _known_scan_ignored_inputs([model.logp()]):
146148
raise NotImplementedInference(
147149
"SVGD does not currently support Minibatch or Simulator RV"
148150
)

0 commit comments

Comments
 (0)