diff --git a/pymc/model/core.py b/pymc/model/core.py
index 69e1fbed72..6f286ee9c8 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -1355,7 +1355,7 @@ def make_obs_var(
     elif not isinstance(data, Variable):
         data = pt.as_tensor_variable(data, name=name)
 
-    if total_size:
+    if total_size is not None:
         from pymc.variational.minibatch_rv import create_minibatch_rv
 
         rv_var = create_minibatch_rv(rv_var, total_size)
diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py
index 877814cd61..fcf42fdf8c 100644
--- a/pymc/model/transform/basic.py
+++ b/pymc/model/transform/basic.py
@@ -13,11 +13,9 @@
 # limitations under the License.
 from collections.abc import Sequence
 
-from pytensor import Variable, clone_replace
+from pytensor import Variable
 from pytensor.graph import ancestors
-from pytensor.graph.fg import FunctionGraph
 
-from pymc.data import MinibatchOp
 from pymc.model.core import Model
 from pymc.model.fgraph import (
     ModelObservedRV,
@@ -60,25 +58,3 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
     else:
         vars_seq = (vars,)
     return [model[var] if isinstance(var, str) else var for var in vars_seq]
-
-
-def remove_minibatched_nodes(model: Model) -> Model:
-    """Remove all uses of pm.Minibatch in the Model."""
-    fgraph, _ = fgraph_from_model(model)
-
-    replacements = {}
-    for var in fgraph.apply_nodes:
-        if isinstance(var.op, MinibatchOp):
-            for inp, out in zip(var.inputs, var.outputs):
-                replacements[out] = inp
-
-    old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths  # type: ignore[attr-defined]
-    # Using `rebuild_strict=False` means all coords, names, and dim information is lost
-    # So we need to restore it from the old fgraph
-    new_outs = clone_replace(old_outs, replacements, rebuild_strict=False)  # type: ignore[arg-type]
-    for old_out, new_out in zip(old_outs, new_outs):
-        new_out.name = old_out.name
-    fgraph = FunctionGraph(outputs=new_outs, clone=False)
-    fgraph._coords = old_coords  # type: ignore[attr-defined]
-    fgraph._dim_lengths = old_dim_lengths  # type: ignore[attr-defined]
-    return model_from_fgraph(fgraph, mutate_fgraph=True)
diff --git a/pymc/model/transform/minibatch.py b/pymc/model/transform/minibatch.py
new file mode 100644
index 0000000000..364e1054c9
--- /dev/null
+++ b/pymc/model/transform/minibatch.py
@@ -0,0 +1,195 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Sequence
+
+from pytensor import Variable
+from pytensor.graph import FunctionGraph, ancestors
+
+from pymc import Minibatch, Model
+from pymc.data import MinibatchOp
+from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph
+from pymc.model.transform.basic import parse_vars
+from pymc.pytensorf import toposort_replace
+from pymc.variational.minibatch_rv import MinibatchRandomVariable
+
+
+def minibatch_model(
+    model: Model,
+    *,
+    batch_size: int,
+    minibatch_vars: Sequence[str | Variable] | None = None,
+) -> Model:
+    """Create a minibatch version of the given Model.
+
+    Replaces minibatch_vars data containers with Minibatch views and rescales
+    the logp of dependent observed variables.
+
+    .. warning:: This transformation acts on the leading dimension of the specified data variables
+        and dependent observed RVs. If a dimension other than the first is linked to the
+        minibatched data variables, the resulting model will be invalid.
+
+    .. warning:: When minibatch_vars is not specified, all non-scalar data variables will be
+        minibatched. This can be incorrect!
+
+    Parameters
+    ----------
+    model : Model
+        The original model to transform.
+    batch_size : int
+        The minibatch size to use.
+    minibatch_vars : Sequence of Variable or string, optional
+        Data variables to convert to minibatch. If None, all non-scalar data variables will be minibatched.
+
+    Returns
+    -------
+    Model
+        A new Model with the specified data variables replaced by Minibatch views
+        and dependent observed RVs adjusted accordingly.
+
+    Raises
+    ------
+    ValueError
+        If any of the specified variables cannot be minibatched (e.g., scalar variables or
+        variables with static leading dimensions), or if dependent variables are Potentials
+        or unobserved RVs.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import numpy as np
+        import pymc as pm
+        from pymc.model.transform.minibatch import minibatch_model
+
+        with pm.Model(coords={"feature": range(4)}) as m:
+            obs_data = pm.Data("obs_data", np.random.normal(size=(100,)))
+            X_data = pm.Data("X_data", np.random.normal(size=(100, 4)))
+            beta = pm.Normal("beta", mu=np.pi, dims="feature")
+
+            mu = X_data @ beta
+            y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data)
+
+        with minibatch_model(m, batch_size=10) as mb:
+            pm.fit()
+    """
+    from pymc.variational.minibatch_rv import create_minibatch_rv
+
+    if minibatch_vars is None:
+        original_minibatch_vars = [
+            variable for variable in model.data_vars if variable.type.ndim > 0
+        ]
+    else:
+        original_minibatch_vars = parse_vars(model, minibatch_vars)
+        for variable in original_minibatch_vars:
+            if variable.type.ndim == 0:
+                raise ValueError(
+                    f"Cannot minibatch {variable.name} because it is a scalar variable."
+                )
+
+    # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed
+    # shape, but mu from data --> y cannot be minibatched because of sigma.
+
+    fgraph, memo = fgraph_from_model(model, inlined_views=True)
+
+    pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
+    minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size)
+
+    # Replace uses of the specified data variables with Minibatch variables.
+    # We need a two-step clone because FunctionGraph can only mutate one variable at a time,
+    # and when there are multiple vars to minibatch you end up replacing the same variable twice recursively.
+    # example: out = x + y
+    # goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)
+    # replace x first we get: out = Minibatch(x, y).0 + y
+    # then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
+    # The second replacement of y ends up creating a circular dependency
+    pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
+    dummy_to_minibatch_var = tuple(
+        (dummy, minibatch_var)
+        for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars)
+    )
+
+    # Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves,
+    # so we use an intermediate FunctionGraph that doesn't contain the data variables as outputs
+    other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
+    minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
+    minibatch_fgraph._coords = fgraph._coords  # type: ignore[attr-defined]
+    minibatch_fgraph._dim_lengths = fgraph._dim_lengths  # type: ignore[attr-defined]
+    toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
+    toposort_replace(minibatch_fgraph, dummy_to_minibatch_var, rebuild=True)
+
+    # Then replace all observed RVs that depend on the minibatched variables with MinibatchRVs
+    dependent_replacements = []
+    total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
+    vars_to_minibatch_set = set(pre_minibatch_vars)
+    for model_var in minibatch_fgraph.outputs:
+        if not (set(ancestors([model_var])) & vars_to_minibatch_set):
+            continue
+        if not isinstance(model_var.owner.op, ModelObservedRV):
+            raise ValueError(
+                "Minibatching only supports observed RVs depending on minibatched variables. "
+                f"Found dependent unobserved variable: {model_var.name}."
+            )
+        # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also have that same dim
+        # and conversely that other variables do not have that dim
+        observed_rv = model_var.owner.inputs[0]
+        minibatch_rv = create_minibatch_rv(observed_rv, total_size=total_size)
+        dependent_replacements.append((observed_rv, minibatch_rv))
+
+    toposort_replace(minibatch_fgraph, dependent_replacements, rebuild=True)
+
+    # Finally reintroduce the original data variable outputs
+    for pre_minibatch_var in pre_minibatch_vars:
+        minibatch_fgraph.add_output(pre_minibatch_var)
+
+    return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)
+
+
+def remove_minibatch(model: Model) -> Model:
+    """Remove all uses of Minibatch data and random variables from the Model.
+
+    Parameters
+    ----------
+    model : Model
+        The original model to transform.
+
+    Returns
+    -------
+    Model
+        A new Model with all Minibatch data variables and MinibatchRVs replaced by their original counterparts.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc.model.transform.minibatch import remove_minibatch
+
+        with pm.Model() as mb:
+            X_data = pm.Data("X_data", [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+            obs_data = pm.Data("obs_data", [1, 2, 3, 4, 5])
+            minibatch_X_data, minibatch_obs_data = pm.Minibatch(X_data, obs_data, batch_size=3)
+
+            beta = pm.Normal("beta", shape=(2,))
+            mu = minibatch_X_data @ beta
+            y = pm.Normal("y", mu=mu, sigma=1, observed=minibatch_obs_data, total_size=(5,))
+
+        with remove_minibatch(mb) as m:
+            idata = pm.sample_prior_predictive()
+            assert idata.prior["y"].shape[-1] == 5  # Original data size restored
+
+    """
+    fgraph, _ = fgraph_from_model(model)
+
+    replacements = []
+    for node in fgraph.apply_nodes:
+        if isinstance(node.op, MinibatchOp):
+            replacements.extend(zip(node.outputs[:-1], node.inputs[:-1]))
+        elif isinstance(node.op, MinibatchRandomVariable):
+            replacements.append((node.outputs[0], node.inputs[0]))
+
+    toposort_replace(fgraph, replacements, rebuild=True)
+    return model_from_fgraph(fgraph, mutate_fgraph=True)
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index d7e097f6dc..e320a93bf5 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 import warnings
 
+from collections import deque
 from collections.abc import Iterable, Sequence
+from itertools import chain
 from typing import cast
 
 import numpy as np
@@ -1032,12 +1034,105 @@ def as_symbolic_string(x, **kwargs):
     return StringConstant(stringtype, x)
 
 
+def _replace_rebuild(
+    fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], **kwargs
+) -> FunctionGraph:
+    """Replace variables and rebuild dependent graph if needed.
+
+    Rebuilding allows for replacements that change the semantics of the graph
+    (different types), which may not be possible for all Ops.
+    """
+    fg_clients = fgraph.clients
+    fg_variables = fgraph.variables
+
+    def get_client_nodes(vars) -> set[Apply]:
+        # Start with the immediate clients of vars
+        nodes = set()
+        d = list(chain.from_iterable(fg_clients[var] for var in vars if var in fg_variables))
+        while d:
+            node, _ = d.pop()
+            if node in nodes or isinstance(node.op, Output):
+                continue
+            nodes.add(node)
+            # Keep walking to the successor clients
+            d.extend(chain.from_iterable(fg_clients[out] for out in node.outputs))
+        return nodes
+
+    repl_dict = dict(replacements)
+    root_nodes = {var.owner for var in repl_dict.keys()}
+
+    # Build sorted queue with all nodes that depend on replaced variables
+    topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
+    d = deque(
+        sorted(
+            get_client_nodes(repl_dict.keys()),
+            key=lambda node: topo_order[node],
+        )
+    )
+    while d:
+        node = d.popleft()
+        if node in root_nodes:
+            continue
+
+        new_inputs = [repl_dict.get(i, i) for i in node.inputs]
+        if new_inputs == node.inputs:
+            continue
+
+        # We need to remake the node if:
+        # 1. The output type depends on an input value
+        # 2. Any of the input types changed
+        if getattr(node.op, "_output_type_depends_on_input_value", False):
+            remake_node = True
+        else:
+            remake_node = any(
+                inp.type != new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
+            )
+
+        if remake_node:
+            new_node = node.clone_with_new_inputs(new_inputs, strict=False)
+            fgraph.import_node(new_node, import_missing=True)
+
+            # We are not always allowed to call `fgraph.replace_all` because the output types may be incompatible
+            # We will keep the changes in repl_dict until we can replace a node without remaking it,
+            # or we arrive at the end of the graph, in which case we need to replace the FunctionGraph output
+            for out, new_out in zip(node.outputs, new_node.outputs):
+                new_out.name = out.name
+                repl_dict[out] = new_out
+        else:
+            fgraph.replace_all(tuple(zip(node.inputs, new_inputs)), import_missing=True)
+
+    # If the FunctionGraph outputs themselves were rebuilt we need to handle them
+    for i, (new_output, old_output) in enumerate(
+        zip(
+            (repl_dict.get(out, out) for out in fgraph.outputs),
+            fgraph.outputs,
+        )
+    ):
+        if new_output is old_output:
+            continue
+        fgraph.outputs[i] = new_output
+        fgraph.import_var(new_output, import_missing=True)
+        fgraph.clients[new_output] = [
+            # Output variables have a special Output Op client
+            # We need to transfer it to the new output.
+            # Any other uses of this output variable will already have been substituted in the loop above,
+            # or are part of other outputs we will substitute next
+            (cl.op.make_node(new_output), idx) if isinstance(cl.op, Output) else (cl, idx)
+            for cl, idx in fgraph.clients[old_output]
+        ]
+    return fgraph
+
+
 def toposort_replace(
     fgraph: FunctionGraph,
     replacements: Sequence[tuple[Variable, Variable]],
     reverse: bool = False,
+    rebuild: bool = False,
 ) -> None:
     """Replace multiple variables in place in topological order."""
+    if rebuild and reverse:
+        raise NotImplementedError("reverse rebuild not yet supported")
+
     fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
     fgraph_toposort[None] = -1  # Variables without owner are not in the toposort
     sorted_replacements = sorted(
@@ -1045,7 +1140,21 @@ def toposort_replace(
         key=lambda pair: fgraph_toposort[pair[0].owner],
         reverse=reverse,
     )
-    fgraph.replace_all(sorted_replacements, import_missing=True)
+
+    if rebuild:
+        if len(replacements) > 1:
+            # In this case we need to modify the replacements recursively with each other
+            sorted_replacements = [list(pairs) for pairs in sorted_replacements]
+            for i in range(1, len(replacements)):
+                temp_fgraph = FunctionGraph(
+                    outputs=[repl for _, repl in sorted_replacements[i:]],
+                    clone=False,
+                )
+                _replace_rebuild(temp_fgraph, replacements=sorted_replacements[:i])
+                sorted_replacements[i][1] = temp_fgraph.outputs[0]
+        _replace_rebuild(fgraph, sorted_replacements)
+    else:
+        fgraph.replace_all(sorted_replacements, import_missing=True)
 
 
 def normalize_rng_param(rng: None | Variable) -> Variable:
diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py
index 34d30cfa50..133165b279 100644
--- a/pymc/variational/minibatch_rv.py
+++ b/pymc/variational/minibatch_rv.py
@@ -19,6 +19,7 @@
 from pytensor import Variable, config
 from pytensor.graph import Apply, Op
 from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable
+from pytensor.tensor.type_other import NoneTypeT
 
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import logp
@@ -33,7 +34,9 @@ class MinibatchRandomVariable(MeasurableOp, Op):
     def make_node(self, rv, *total_size):
         rv = as_tensor_variable(rv)
         total_size = [
-            as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst
+            t
+            if isinstance(t, Variable)
+            else (NoneConst if t is None else as_tensor_variable(t, dtype="int64", ndim=0))
             for t in total_size
         ]
         assert len(total_size) == rv.ndim
@@ -55,45 +58,67 @@ def perform(self, node, inputs, output_storage):
 
 
 def create_minibatch_rv(
     rv: TensorVariable,
-    total_size: int | None | Sequence[int | EllipsisType | None],
+    total_size: int | TensorVariable | Sequence[int | TensorVariable | EllipsisType | None],
 ) -> TensorVariable:
     """Create variable whose logp is rescaled by total_size."""
+    rv_ndim_supp = rv.owner.op.ndim_supp
+
     if isinstance(total_size, int):
-        if rv.ndim <= 1:
-            total_size = [total_size]
+        total_size = (total_size, *([None] * rv_ndim_supp))
+    elif isinstance(total_size, TensorVariable):
+        if total_size.type.ndim == 0:
+            total_size = (total_size, *([None] * rv_ndim_supp))
+        elif total_size.type.ndim == 1:
+            total_size = tuple(total_size)
         else:
-            missing_ndims = rv.ndim - 1
-            total_size = [total_size] + [None] * missing_ndims
-    elif isinstance(total_size, list | tuple):
-        total_size = list(total_size)
-        if Ellipsis in total_size:
-            # Replace Ellipsis by None
-            if total_size.count(Ellipsis) > 1:
-                raise ValueError("Only one Ellipsis can be present in total_size")
-            sep = total_size.index(Ellipsis)
-            begin = total_size[:sep]
-            end = total_size[sep + 1 :]
-            missing_ndims = max((rv.ndim - len(begin) - len(end), 0))
-            total_size = begin + [None] * missing_ndims + end
-        if len(total_size) > rv.ndim:
-            raise ValueError(f"Length of total_size {total_size} is langer than RV ndim {rv.ndim}")
-    else:
-        raise TypeError(f"Invalid type for total_size: {total_size}")
-
-    return cast(TensorVariable, minibatch_rv(rv, *total_size))
-
-
-def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable:
+            raise ValueError(
+                f"total_size must be a 0d or 1d tensor, got {total_size} with {total_size.type.ndim} dimensions"
+            )
+
+    if not isinstance(total_size, list | tuple):
+        raise ValueError(f"Invalid type for total_size {total_size}: {type(total_size)}")
+
+    if Ellipsis in total_size:
+        # Replace Ellipsis by None
+        if total_size.count(Ellipsis) > 1:
+            raise ValueError("Only one Ellipsis can be present in total_size")
+        sep = total_size.index(Ellipsis)
+        begin = total_size[:sep]
+        end = total_size[sep + 1 :]
+        missing_ndims = max((rv_ndim_supp - len(begin) - len(end), 0))
+        total_size = (*begin, *([None] * missing_ndims), *end)
+
+    if (len(total_size) - rv_ndim_supp) not in (0, 1):
+        raise ValueError(
+            f"Length of total_size {total_size} not compatible with ndim_supp of RV {rv}, "
+            f"got {len(total_size)} but must be {rv_ndim_supp} or {rv_ndim_supp + 1}"
+        )
+
+    out = minibatch_rv(rv, *total_size)
+    assert isinstance(out.owner.op, MinibatchRandomVariable)
+    return cast(TensorVariable, out)
+
+
+def get_scaling(
+    total_size: Sequence[TensorVariable], shape: TensorVariable | Sequence[TensorVariable]
+) -> TensorVariable:
     """Get scaling constant for logp."""
     # mypy doesn't understand we can convert a shape TensorVariable into a tuple
-    shape = tuple(shape)  # type: ignore[assignment]
+    shape = tuple(shape)
+
+    if len(total_size) == (len(shape) - 1):
+        # This happens when RV has no batch dimensions
+        # In that case the total_size corresponds to a dummy shape of 1
+        total_size = (1, *total_size)
+
+    assert len(shape) == len(total_size)
 
-    # Scalar RV
-    if len(shape) == 0:  # type: ignore[arg-type]
-        coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0
-    else:
-        coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)]
-        coef = pt.prod(coefs)
+    coefs = [
+        size / dim_length
+        for size, dim_length in zip(total_size, shape)
+        if not isinstance(size.type, NoneTypeT)
+    ]
+    coef = pt.prod(coefs) if len(coefs) > 1 else coefs[0]
 
     return pt.cast(coef, dtype=config.floatX)
 
@@ -102,4 +127,6 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor
 def minibatch_rv_logprob(op, values, *inputs, **kwargs):
     [value] = values
     rv, *total_size = inputs
-    return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape)
+    raw_logp = logp(rv, value, **kwargs)
+    scaled_logp = raw_logp * get_scaling(total_size, raw_logp.shape)
+    return scaled_logp
diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py
index 856fbf0b2b..4042ee0da3 100644
--- a/tests/model/transform/test_basic.py
+++ b/tests/model/transform/test_basic.py
@@ -11,41 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
-
-import pymc as pm
-
-from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes
+from pymc import Data, Model
+from pymc.distributions import Normal
+from pymc.model.transform.basic import (
+    prune_vars_detached_from_observed,
+)
 
 
 def test_prune_vars_detached_from_observed():
-    with pm.Model() as m:
-        obs_data = pm.Data("obs_data", 0)
-        a0 = pm.Data("a0", 0)
-        a1 = pm.Normal("a1", a0)
-        a2 = pm.Normal("a2", a1)
-        pm.Normal("obs", a2, observed=obs_data)
+    with Model() as m:
+        obs_data = Data("obs_data", 0)
+        a0 = Data("a0", 0)
+        a1 = Normal("a1", a0)
+        a2 = Normal("a2", a1)
+        Normal("obs", a2, observed=obs_data)
 
-        d0 = pm.Data("d0", 0)
-        d1 = pm.Normal("d1", d0)
+        d0 = Data("d0", 0)
+        d1 = Normal("d1", d0)
 
     assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"}
     pruned_m = prune_vars_detached_from_observed(m)
     assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"}
-
-
-def test_remove_minibatches():
-    data_size = 100
-    data = np.zeros((data_size,))
-    batch_size = 10
-    with pm.Model(coords={"d": range(5)}) as m1:
-        mb = pm.Minibatch(data, batch_size=batch_size)
-        mu = pm.Normal("mu", dims="d")
-        x = pm.Normal("x")
-        y = pm.Normal("y", x, observed=mb, total_size=100)
-
-    m2 = remove_minibatched_nodes(m1)
-    assert m1.y.shape[0].eval() == batch_size
-    assert m2.y.shape[0].eval() == data_size
-    assert m1.coords == m2.coords
-    assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval()
diff --git a/tests/model/transform/test_minibatch.py b/tests/model/transform/test_minibatch.py
new file mode 100644
index 0000000000..20b9c0a2a7
--- /dev/null
+++ b/tests/model/transform/test_minibatch.py
@@ -0,0 +1,122 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+
+from pymc.data import Data, Minibatch
+from pymc.distributions import HalfNormal, Normal
+from pymc.model.core import Model
+from pymc.model.transform.minibatch import minibatch_model, remove_minibatch
+from pymc.variational.minibatch_rv import MinibatchRandomVariable
+
+
+def test_minibatch_model():
+    data_size = 100
+    n_features = 4
+
+    obs_data_np = np.random.normal(size=(data_size,))
+    X_data_np = np.random.normal(size=(data_size, n_features))
+
+    with Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m:
+        obs_data = Data("obs_data", obs_data_np, dims=["data_dim"])
+        X_data = Data("X_data", X_data_np, dims=["data_dim", "feature"])
+        beta = Normal("beta", mu=np.pi, dims="feature")
+
+        mu = X_data @ beta
+        y = Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim")
+
+    with Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as ref_m:
+        obs_data = Data("obs_data", obs_data_np, dims=["data_dim"])
+        X_data = Data("X_data", X_data_np, dims=["data_dim", "feature"])
+        minibatch_obs_data, minibatch_X_data = Minibatch(obs_data, X_data, batch_size=10)
+        beta = Normal("beta", mu=np.pi, dims="feature")
+        mu = minibatch_X_data @ beta
+        y = Normal(
+            "y",
+            mu=mu,
+            sigma=1,
+            observed=minibatch_obs_data,
+            dims="data_dim",
+            total_size=(obs_data.shape[0], ...),
+        )
+
+    mb = minibatch_model(m, batch_size=10)
+    mb_logp_fn = mb.compile_logp(random_seed=42)
+    ref_mb_logp_fn = ref_m.compile_logp(random_seed=42)
+    ip = mb.initial_point()
+
+    mb_res1 = mb_logp_fn(ip)
+    ref_mb_res1 = ref_mb_logp_fn(ip)
+    np.testing.assert_allclose(mb_res1, ref_mb_res1)
+    mb_res2 = mb_logp_fn(ip)
+    # Minibatch should give different results on each call
+    assert mb_res1 != mb_res2
+    ref_mb_res2 = ref_mb_logp_fn(ip)
+    np.testing.assert_allclose(mb_res2, ref_mb_res2)
+
+
+def test_remove_minibatch():
+    data_size = 100
+    n_features = 5
+    batch_size = 10
+    with Model(coords={"d": range(n_features)}) as mb:
+        X_data = Data("X_data", np.random.normal(size=(data_size, n_features)))
+        obs_data = Data("obs_data", np.random.normal(size=(data_size,)))
+        minibatch_X_data, minibatch_obs_data = Minibatch(X_data, obs_data, batch_size=batch_size)
+
+        beta = Normal("beta", dims=("d",))
+        mu = minibatch_X_data @ beta
+        sigma = HalfNormal("sigma")
+        y = Normal("y", mu=mu, sigma=sigma, observed=minibatch_obs_data, total_size=X_data.shape[0])
+
+    m = remove_minibatch(mb)
+    assert isinstance(mb.y.owner.op, MinibatchRandomVariable)
+    assert tuple(mb.y.shape.eval()) == (batch_size,)
+    assert isinstance(m.y.owner.op, Normal)
+    assert tuple(m.y.shape.eval()) == (data_size,)
+    assert mb.coords == m.coords
+    assert mb.dim_lengths["d"].eval() == m.dim_lengths["d"].eval()
+
+
+@pytest.mark.parametrize("static_shape", (True, False))
+def test_minibatch_transform_roundtrip(static_shape):
+    data_size = 100
+    n_features = 4
+    with Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m:
+        obs_data = Data(
+            "obs_data",
+            np.random.normal(size=(data_size,)),
+            dims=["data_dim"],
+            shape=(data_size if static_shape else None,),
+        )
+        X_data = Data(
+            "X_data",
+            np.random.normal(size=(data_size, n_features)),
+            dims=["data_dim", "feature"],
+            shape=(data_size if static_shape else None, n_features),
+        )
+        beta = Normal("beta", mu=np.pi, dims="feature")
+
+        mu = X_data @ beta
+        y = Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim")
+
+    m_again = remove_minibatch(minibatch_model(m, batch_size=10))
+    m_again_logp_fn = m_again.compile_logp(random_seed=42)
+    m_logp_fn = m.compile_logp(random_seed=42)
+    ip = m_again.initial_point()
+    m_again_res = m_again_logp_fn(ip)
+    m_res = m_logp_fn(ip)
+    np.testing.assert_allclose(m_again_res, m_res)
+    # Check that repeated calls give the same result (no more minibatching)
+    np.testing.assert_allclose(m_again_res, m_again_logp_fn(ip))
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index d172c61a4d..5d4e7d2d3d 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -24,6 +24,7 @@
 from pytensor import scan, shared
 from pytensor.compile import UnusedInputError
 from pytensor.compile.builders import OpFromGraph
+from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Variable, equal_computations
 from pytensor.tensor.subtensor import AdvancedIncSubtensor
 
@@ -46,6 +47,7 @@
     replace_rng_nodes,
     replace_vars_in_graphs,
     reseed_rngs,
+    toposort_replace,
 )
 from pymc.vartypes import int_types
 
@@ -785,3 +787,68 @@ def test_pickle_point_func():
     np.testing.assert_allclose(
         point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]})
     )
+
+
+class TestToposortReplace:
+    @pytest.mark.parametrize("compatible_type", (True, False))
+    @pytest.mark.parametrize("num_replacements", (1, 2))
+    @pytest.mark.parametrize("rebuild", (True, False))
+    def test_horizontal_dependency(self, compatible_type, num_replacements, rebuild):
+        x = pt.vector("x", shape=(5,))
+        y = pt.vector("y", shape=(5,))
+
+        out1 = pt.exp(x + y) + pt.log(x + y)
+        out2 = pt.cos(out1)
+
+        new_shape = (5,) if compatible_type else (10,)
+        new_x = pt.vector("new_x", shape=new_shape)
+        new_y = pt.vector("new_y", shape=new_shape)
+        if num_replacements == 1:
+            replacements = [(y, new_y)]
+        else:
+            replacements = [(x, new_x), (y, new_y)]
+
+        fg = FunctionGraph([x, y], [out1, out2], clone=False)
+
+        # If types are incompatible, and we don't rebuild or only replace one of the variables,
+        # the function should fail
+        if not compatible_type and (not rebuild or num_replacements == 1):
+            with pytest.raises((TypeError, ValueError)):
+                toposort_replace(fg, replacements, rebuild=rebuild)
+            return
+        toposort_replace(fg, replacements, rebuild=rebuild)
+
+        if num_replacements == 1:
+            expected_out1 = pt.exp(x + new_y) + pt.log(x + new_y)
+        else:
+            expected_out1 = pt.exp(new_x + new_y) + pt.log(new_x + new_y)
+        expected_out2 = pt.cos(expected_out1)
+        assert equal_computations(fg.outputs, [expected_out1, expected_out2])
+
+    @pytest.mark.parametrize("compatible_type", (True, False))
+    @pytest.mark.parametrize("num_replacements", (2, 3))
+    @pytest.mark.parametrize("rebuild", (True, False))
+    def test_vertical_dependency(self, compatible_type, num_replacements, rebuild):
+        x = pt.vector("x", shape=(5,))
+        a1 = pt.exp(x)
+        a2 = pt.log(a1)
+        out = a1 + a2
+
+        new_x = pt.vector("new_x", shape=(5 if compatible_type else 10,))
+        if num_replacements == 3:
+            replacements = [(x, new_x), (a1, pt.cos(a1)), (a2, pt.sin(a2 + 5))]
+        else:
+            replacements = [(a1, pt.cos(pt.exp(new_x))), (a2, pt.sin(a2 + 5))]
+
+        fg = FunctionGraph([x], [out], clone=False)
+
+        if not compatible_type and not rebuild:
+            with pytest.raises(TypeError):
+                toposort_replace(fg, replacements, rebuild=rebuild)
+            return
+        toposort_replace(fg, replacements, rebuild=rebuild)
+
+        expected_a1 = pt.cos(pt.exp(new_x))
+        expected_a2 = pt.sin(pt.log(expected_a1) + 5)
+        expected_out = expected_a1 + expected_a2
+        assert equal_computations(fg.outputs, [expected_out])
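
Example usage (illustrative sketch, not part of the patch): the snippet below chains the two new transforms as described in the docstrings above, fitting a variational approximation on a minibatched copy of a model and then recovering the full-data model. The data, sizes, and variable names are made up for the example.

import numpy as np
import pymc as pm

from pymc.model.transform.minibatch import minibatch_model, remove_minibatch

rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 3))
y = X @ [0.5, -1.0, 2.0] + rng.normal(size=1_000)

with pm.Model(coords={"feature": range(3)}) as model:
    X_data = pm.Data("X_data", X)
    y_data = pm.Data("y_data", y)
    beta = pm.Normal("beta", dims="feature")
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=X_data @ beta, sigma=sigma, observed=y_data)

# ADVI runs on a minibatched copy; the original model is left untouched
with minibatch_model(model, batch_size=100, minibatch_vars=["X_data", "y_data"]) as mb:
    approx = pm.fit(n=10_000)

# Round-tripping recovers a full-data model, e.g. for posterior predictive checks
full_model = remove_minibatch(mb)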