From c4b0d5726a1659180f8c26a9c1dd70ca1d1d5823 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 15 May 2025 20:24:52 +0800 Subject: [PATCH 1/7] initial PR --- pymc/model/transform/basic.py | 47 ++++++++++++++++++++++++++++- tests/model/transform/test_basic.py | 28 ++++++++++++++++- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 877814cd61..558275d97f 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -14,17 +14,21 @@ from collections.abc import Sequence from pytensor import Variable, clone_replace +from pytensor.compile import SharedVariable from pytensor.graph import ancestors from pytensor.graph.fg import FunctionGraph -from pymc.data import MinibatchOp +from pymc.data import Minibatch, MinibatchOp from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, ModelVar, + extract_dims, fgraph_from_model, model_from_fgraph, + model_observed_rv, ) +from pymc.pytensorf import toposort_replace ModelVariable = Variable | str @@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l return [model[var] if isinstance(var, str) else var for var in vars_seq] +def model_to_minibatch(model: Model, batch_size: int) -> Model: + """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" + from pymc.variational.minibatch_rv import create_minibatch_rv + + fgraph, memo = fgraph_from_model(model, inlined_views=True) + + # obs_rvs, data_vars = model.rvs_to_values.items() + + data_vars = [ + memo[datum].owner.inputs[0] + for datum in (model.named_vars[datum_name] for datum_name in model.named_vars) + if isinstance(datum, SharedVariable) + ] + + minibatch_vars = Minibatch(*data_vars, batch_size=batch_size) + replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)} + assert 0 + # Add total_size to all observed RVs + total_size = data_vars[0].get_value().shape[0] + for obs_var in model.observed_RVs: + model_var = memo[obs_var] + var = model_var.owner.inputs[0] + var.name = model_var.name + dims = extract_dims(model_var) + + new_rv = create_minibatch_rv(var, total_size=total_size) + new_rv.name = var.name + + replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims) + + # old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths + toposort_replace(fgraph, tuple(replacements.items())) + # new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] + + # fgraph = FunctionGraph(outputs=new_outs, clone=False) + # fgraph._coords = old_coords # type: ignore[attr-defined] + # fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] + + return model_from_fgraph(fgraph, mutate_fgraph=True) + + def remove_minibatched_nodes(model: Model) -> Model: """Remove all uses of pm.Minibatch in the Model.""" fgraph, _ = fgraph_from_model(model) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 856fbf0b2b..c3b33730db 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -15,7 +15,11 @@ import pymc as pm -from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes +from pymc.model.transform.basic import ( + model_to_minibatch, + prune_vars_detached_from_observed, + remove_minibatched_nodes, +) def test_prune_vars_detached_from_observed(): @@ -34,6 +38,28 @@ def test_prune_vars_detached_from_observed(): assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} +def test_model_to_minibatch(): + data_size = 100 + n_features = 4 + + obs_data = np.zeros((data_size,)) + X_data = np.random.normal(size=(data_size, n_features)) + + with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1: + obs_data = pm.Data("obs_data", obs_data, dims=["data_dim"]) + X_data = pm.Data("X_data", X_data, dims=["data_dim", "feature"]) + beta = pm.Normal("beta", dims="feature") + + mu = X_data @ beta + + y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") + + m2 = model_to_minibatch(m1, batch_size=10) + m2["y"].dprint() + + assert 0 + + def test_remove_minibatches(): data_size = 100 data = np.zeros((data_size,)) From 093a7bef643ad7db2cb1f45b4bd13b8f4d691c08 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 16 Nov 2025 13:23:48 -0600 Subject: [PATCH 2/7] Working model_to_minibatch implementation --- pymc/model/transform/basic.py | 108 ++++++++++++++++++++-------- tests/model/transform/test_basic.py | 30 ++++++-- 2 files changed, 101 insertions(+), 37 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 558275d97f..e9ffac8731 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -14,7 +14,6 @@ from collections.abc import Sequence from pytensor import Variable, clone_replace -from pytensor.compile import SharedVariable from pytensor.graph import ancestors from pytensor.graph.fg import FunctionGraph @@ -23,10 +22,8 @@ from pymc.model.fgraph import ( ModelObservedRV, ModelVar, - extract_dims, fgraph_from_model, model_from_fgraph, - model_observed_rv, ) from pymc.pytensorf import toposort_replace @@ -66,45 +63,96 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l return [model[var] if isinstance(var, str) else var for var in vars_seq] -def model_to_minibatch(model: Model, batch_size: int) -> Model: +def model_to_minibatch( + model: Model, *, batch_size: int, vars_to_minibatch: list[str] | None = None +) -> Model: """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" from pymc.variational.minibatch_rv import create_minibatch_rv + if vars_to_minibatch is None: + vars_to_minibatch = [ + variable + for variable in model.data_vars + if (variable.type.ndim > 0) and (variable.type.shape[0] is None) + ] + else: + vars_to_minibatch = parse_vars(model, vars_to_minibatch) + for variable in vars_to_minibatch: + if variable.type.ndim == 0: + raise ValueError( + f"Cannot minibatch {variable.name} because it is a scalar variable." + ) + if variable.type.shape[0] is not None: + raise ValueError( + f"Cannot minibatch {variable.name} because its first dimension is static " + f"(size={variable.type.shape[0]})." + ) + + # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed + # shape, but mu from data --> y cannot be minibatched because of sigma. + fgraph, memo = fgraph_from_model(model, inlined_views=True) - # obs_rvs, data_vars = model.rvs_to_values.items() + cloned_vars_to_minibatch = [memo[var] for var in vars_to_minibatch] + minibatch_vars = Minibatch(*cloned_vars_to_minibatch, batch_size=batch_size) - data_vars = [ - memo[datum].owner.inputs[0] - for datum in (model.named_vars[datum_name] for datum_name in model.named_vars) - if isinstance(datum, SharedVariable) - ] + var_to_dummy = { + var: var.type() # model_named(minibatch_var, *extract_dims(var)) + for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars) + } + dummy_to_minibatch = { + var_to_dummy[var]: minibatch_var + for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars) + } + total_size = (cloned_vars_to_minibatch[0].owner.inputs[0].shape[0], ...) - minibatch_vars = Minibatch(*data_vars, batch_size=batch_size) - replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)} - assert 0 - # Add total_size to all observed RVs - total_size = data_vars[0].get_value().shape[0] - for obs_var in model.observed_RVs: - model_var = memo[obs_var] - var = model_var.owner.inputs[0] - var.name = model_var.name - dims = extract_dims(model_var) + # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim + # (or just do this all in xtensor) + vtm_set = set(vars_to_minibatch) - new_rv = create_minibatch_rv(var, total_size=total_size) - new_rv.name = var.name + # TODO: Handle potentials, free_RVs, etc - replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims) + # Create a temporary fgraph that does not include as outputs any of the variables that will be minibatched. This + # ensures the results of this function match the outputs from a model constructed using the pm.Minibatch API. + tmp_fgraph = FunctionGraph( + outputs=[out for out in fgraph.outputs if out not in var_to_dummy.keys()], clone=False + ) - # old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths - toposort_replace(fgraph, tuple(replacements.items())) - # new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] + # All variables that will be minibatched are first replaced by dummy variables, to avoid infinite recursion during + # rewrites. The issue is that the Minibatch Op we will introduce depends on the original input variables (to get + # the shapes). That's fine in the final output, but during the intermediate rewrites this creates a circulatr + # dependency. + dummy_replacements = tuple(var_to_dummy.items()) + toposort_replace(tmp_fgraph, dummy_replacements) - # fgraph = FunctionGraph(outputs=new_outs, clone=False) - # fgraph._coords = old_coords # type: ignore[attr-defined] - # fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] + # Now we can replace the dummy variables with the actual Minibatch variables. + replacements = tuple(dummy_to_minibatch.items()) + toposort_replace(tmp_fgraph, replacements) - return model_from_fgraph(fgraph, mutate_fgraph=True) + # The last step is to replace all RVs that depend on the minibatched variables with MinibatchRVs that are aware + # of the total_size. Importantly, all of the toposort_replace calls above modify fgraph in place, so the + # model.rvs_to_values[original_rv] will already have been modified to depend on the Minibatch variables -- only + # the outer RVs need to be replaced here. + dependent_replacements = {} + + for original_rv in model.observed_RVs: + original_value_var = model.rvs_to_values[original_rv] + + if not (set(ancestors([original_rv, original_value_var])) & vtm_set): + continue + + rv = memo[original_rv].owner.inputs[0] + dependent_replacements[rv] = create_minibatch_rv(rv, total_size=total_size) + + toposort_replace(fgraph, tuple(dependent_replacements.items())) + + # FIXME: The fgraph is being rebuilt here to clean up the clients. It is not clear why they are getting messed up + # in the first place (pytensor bug, or something wrong in the above manipulations?) + new_fgraph = FunctionGraph(outputs=fgraph.outputs) + new_fgraph._coords = fgraph._coords # type: ignore[attr-defined] + new_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] + + return model_from_fgraph(new_fgraph, mutate_fgraph=True) def remove_minibatched_nodes(model: Model) -> Model: diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index c3b33730db..757e866392 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -20,6 +20,7 @@ prune_vars_detached_from_observed, remove_minibatched_nodes, ) +from pymc.testing import assert_equivalent_models def test_prune_vars_detached_from_observed(): @@ -42,22 +43,37 @@ def test_model_to_minibatch(): data_size = 100 n_features = 4 - obs_data = np.zeros((data_size,)) - X_data = np.random.normal(size=(data_size, n_features)) + obs_data_np = np.zeros((data_size,)) + X_data_np = np.random.normal(size=(data_size, n_features)) with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1: - obs_data = pm.Data("obs_data", obs_data, dims=["data_dim"]) - X_data = pm.Data("X_data", X_data, dims=["data_dim", "feature"]) + obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"]) + X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"]) beta = pm.Normal("beta", dims="feature") mu = X_data @ beta y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") - m2 = model_to_minibatch(m1, batch_size=10) - m2["y"].dprint() + with pm.Model( + coords={"feature": range(n_features), "data_dim": range(data_size)} + ) as reference_model: + obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"]) + X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"]) + minibatch_obs_data, minibatch_X_data = pm.Minibatch(obs_data, X_data, batch_size=10) + beta = pm.Normal("beta", dims="feature") + mu = minibatch_X_data @ beta + y = pm.Normal( + "y", + mu=mu, + sigma=1, + observed=minibatch_obs_data, + dims="data_dim", + total_size=(obs_data.shape[0], ...), + ) - assert 0 + m2 = model_to_minibatch(m1, batch_size=10) + assert_equivalent_models(m2, reference_model) def test_remove_minibatches(): From 4d18e37963840b70b1e286bc3198e1a69ba2aed9 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 16 Nov 2025 13:25:13 -0600 Subject: [PATCH 3/7] Add `assert_equivalent_models` test helper --- pymc/data.py | 3 +++ pymc/testing.py | 38 +++++++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index cfade37910..5094c8f510 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -92,6 +92,9 @@ class MinibatchIndexRV(IntegersRV): class MinibatchOp(OpFromGraph): """Encapsulate Minibatch random draws in an opaque OFG.""" + # FIXME: __props__ should not be empty + __props__ = () + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, inline=True) diff --git a/pymc/testing.py b/pymc/testing.py index 551356ee33..2c3ce51db5 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -47,6 +47,8 @@ ParameterValueError, local_check_parameter_to_ninf_switch, ) +from pymc.model import Model +from pymc.model.fgraph import fgraph_from_model from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph # This mode can be used for tests where model compilations takes the bulk of the runtime @@ -239,7 +241,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None): if extra_args is None: extra_args = {} - with pm.Model() as m: + with Model() as m: param_vars = {} for v, dom in vardomains.items(): v_pt = pytensor.shared(np.asarray(dom.vals[0])) @@ -1209,3 +1211,37 @@ def equal_computations_up_to_root( return False return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type] + + +def assert_equivalent_models(model1: Model, model2: Model): + """Check whether two PyMC models are equivalent. + + Examples + -------- + + .. code-block:: python + + import pymc as pm + from pymc_extras.utils.model_equivalence import equivalent_models + + with pm.Model() as m1: + x = pm.Normal("x") + y = pm.Normal("y", x) + + with pm.Model() as m2: + x = pm.Normal("x") + y = pm.Normal("y", x + 1) + + with pm.Model() as m3: + x = pm.Normal("x") + y = pm.Normal("y", x) + + assert not equivalent_models(m1, m2) + assert equivalent_models(m1, m3) + + """ + fgraph1, _ = fgraph_from_model(model1) + fgraph2, _ = fgraph_from_model(model2) + + are_equivalent = equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs) + assert are_equivalent, "Models are not equivalent" From 9ab26a862b79052e3c8256212e7ef4c307a53d3c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Nov 2025 15:30:07 +0100 Subject: [PATCH 4/7] Cleanup implementation and test --- pymc/model/transform/basic.py | 114 +++++++++++++--------------- tests/model/transform/test_basic.py | 39 +++++++--- 2 files changed, 81 insertions(+), 72 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index e9ffac8731..bdc66469d4 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -64,20 +64,20 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l def model_to_minibatch( - model: Model, *, batch_size: int, vars_to_minibatch: list[str] | None = None + model: Model, *, batch_size: int, minibatch_vars: list[str] | None = None ) -> Model: """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" from pymc.variational.minibatch_rv import create_minibatch_rv - if vars_to_minibatch is None: - vars_to_minibatch = [ + if minibatch_vars is None: + original_minibatch_vars = [ variable for variable in model.data_vars if (variable.type.ndim > 0) and (variable.type.shape[0] is None) ] else: - vars_to_minibatch = parse_vars(model, vars_to_minibatch) - for variable in vars_to_minibatch: + original_minibatch_vars = parse_vars(model, minibatch_vars) + for variable in original_minibatch_vars: if variable.type.ndim == 0: raise ValueError( f"Cannot minibatch {variable.name} because it is a scalar variable." @@ -93,66 +93,58 @@ def model_to_minibatch( fgraph, memo = fgraph_from_model(model, inlined_views=True) - cloned_vars_to_minibatch = [memo[var] for var in vars_to_minibatch] - minibatch_vars = Minibatch(*cloned_vars_to_minibatch, batch_size=batch_size) - - var_to_dummy = { - var: var.type() # model_named(minibatch_var, *extract_dims(var)) - for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars) - } - dummy_to_minibatch = { - var_to_dummy[var]: minibatch_var - for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars) - } - total_size = (cloned_vars_to_minibatch[0].owner.inputs[0].shape[0], ...) - - # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim - # (or just do this all in xtensor) - vtm_set = set(vars_to_minibatch) - - # TODO: Handle potentials, free_RVs, etc - - # Create a temporary fgraph that does not include as outputs any of the variables that will be minibatched. This - # ensures the results of this function match the outputs from a model constructed using the pm.Minibatch API. - tmp_fgraph = FunctionGraph( - outputs=[out for out in fgraph.outputs if out not in var_to_dummy.keys()], clone=False + pre_minibatch_vars = [memo[var] for var in original_minibatch_vars] + minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size) + + # Replace uses of the specified data variables with Minibatch variables + # We need a two-step clone because FunctionGraph can only mutate one variable at a time + # and when there are multiple vars to minibatch you end up replacing the same variable twice recursively + # exampre: out = x + y + # goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)] + # replace x first we get: out = Minibatch(x, y).0 + y + # then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1 + # The second replacement of y ends up creating a circular dependency + pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars) + dummy_to_minibatch_var = tuple( + (dummy, minibatch_var) + for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars) ) - # All variables that will be minibatched are first replaced by dummy variables, to avoid infinite recursion during - # rewrites. The issue is that the Minibatch Op we will introduce depends on the original input variables (to get - # the shapes). That's fine in the final output, but during the intermediate rewrites this creates a circulatr - # dependency. - dummy_replacements = tuple(var_to_dummy.items()) - toposort_replace(tmp_fgraph, dummy_replacements) - - # Now we can replace the dummy variables with the actual Minibatch variables. - replacements = tuple(dummy_to_minibatch.items()) - toposort_replace(tmp_fgraph, replacements) - - # The last step is to replace all RVs that depend on the minibatched variables with MinibatchRVs that are aware - # of the total_size. Importantly, all of the toposort_replace calls above modify fgraph in place, so the - # model.rvs_to_values[original_rv] will already have been modified to depend on the Minibatch variables -- only - # the outer RVs need to be replaced here. - dependent_replacements = {} + # Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves, + # So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs + other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars] + minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False) + minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined] + minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] + toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy) + toposort_replace(minibatch_fgraph, dummy_to_minibatch_var) - for original_rv in model.observed_RVs: - original_value_var = model.rvs_to_values[original_rv] - - if not (set(ancestors([original_rv, original_value_var])) & vtm_set): + # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs + dependent_replacements = {} + total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...) + vars_to_minibatch_set = set(pre_minibatch_vars) + for model_var in minibatch_fgraph.outputs: + if not (set(ancestors([model_var])) & vars_to_minibatch_set): continue - - rv = memo[original_rv].owner.inputs[0] - dependent_replacements[rv] = create_minibatch_rv(rv, total_size=total_size) - - toposort_replace(fgraph, tuple(dependent_replacements.items())) - - # FIXME: The fgraph is being rebuilt here to clean up the clients. It is not clear why they are getting messed up - # in the first place (pytensor bug, or something wrong in the above manipulations?) - new_fgraph = FunctionGraph(outputs=fgraph.outputs) - new_fgraph._coords = fgraph._coords # type: ignore[attr-defined] - new_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] - - return model_from_fgraph(new_fgraph, mutate_fgraph=True) + if not isinstance(model_var.owner.op, ModelObservedRV): + raise ValueError( + "Minibatching only supports observed RVs depending on minibatched variables. " + f"Found dependent unobserved variable: {model_var.name}." + ) + # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim + # And conversely other variables do not have that dim + observed_rv = model_var.owner.inputs[0] + dependent_replacements[observed_rv] = create_minibatch_rv( + observed_rv, total_size=total_size + ) + + toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items())) + + # Finally reintroduce the original data variable outputs + for pre_minibatch_var in pre_minibatch_vars: + minibatch_fgraph.add_output(pre_minibatch_var) + + return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True) def remove_minibatched_nodes(model: Model) -> Model: diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 757e866392..75dd6bf9bc 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -20,7 +20,6 @@ prune_vars_detached_from_observed, remove_minibatched_nodes, ) -from pymc.testing import assert_equivalent_models def test_prune_vars_detached_from_observed(): @@ -43,25 +42,22 @@ def test_model_to_minibatch(): data_size = 100 n_features = 4 - obs_data_np = np.zeros((data_size,)) + obs_data_np = np.random.normal(size=(data_size,)) X_data_np = np.random.normal(size=(data_size, n_features)) - with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1: + with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m: obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"]) X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"]) - beta = pm.Normal("beta", dims="feature") + beta = pm.Normal("beta", mu=np.pi, dims="feature") mu = X_data @ beta - y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") - with pm.Model( - coords={"feature": range(n_features), "data_dim": range(data_size)} - ) as reference_model: + with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as ref_m: obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"]) X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"]) minibatch_obs_data, minibatch_X_data = pm.Minibatch(obs_data, X_data, batch_size=10) - beta = pm.Normal("beta", dims="feature") + beta = pm.Normal("beta", mu=np.pi, dims="feature") mu = minibatch_X_data @ beta y = pm.Normal( "y", @@ -72,8 +68,29 @@ def test_model_to_minibatch(): total_size=(obs_data.shape[0], ...), ) - m2 = model_to_minibatch(m1, batch_size=10) - assert_equivalent_models(m2, reference_model) + mb = model_to_minibatch(m, batch_size=10) + mb_logp_fn = mb.compile_logp(random_seed=42) + ref_mb_logp_fn = ref_m.compile_logp(random_seed=42) + ip = mb.initial_point() + + mb_res1 = mb_logp_fn(ip) + ref_mb_res1 = ref_mb_logp_fn(ip) + np.testing.assert_allclose(mb_res1, ref_mb_res1) + mb_res2 = mb_logp_fn(ip) + # Minibatch should give different results on each call + assert mb_res1 != mb_res2 + ref_mb_res2 = ref_mb_logp_fn(ip) + np.testing.assert_allclose(mb_res2, ref_mb_res2) + + m_again = remove_minibatched_nodes(mb) + m_again_logp_fn = m_again.compile_logp(random_seed=42) + m_logp_fn = m_again.compile_logp(random_seed=42) + ip = m_again.initial_point() + m_again_res = m_again_logp_fn(ip) + m_res = m_logp_fn(ip) + np.testing.assert_allclose(m_again_res, m_res) + # Check that repeated calls give the same result (no more minibatching) + np.testing.assert_allclose(m_again_res, m_again_logp_fn(ip)) def test_remove_minibatches(): From a35046709089adf31b62d730916aad0df46e15ce Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Nov 2025 15:30:12 +0100 Subject: [PATCH 5/7] Revert "Add `assert_equivalent_models` test helper" This reverts commit 4d18e37963840b70b1e286bc3198e1a69ba2aed9. --- pymc/data.py | 3 --- pymc/testing.py | 38 +------------------------------------- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 5094c8f510..cfade37910 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -92,9 +92,6 @@ class MinibatchIndexRV(IntegersRV): class MinibatchOp(OpFromGraph): """Encapsulate Minibatch random draws in an opaque OFG.""" - # FIXME: __props__ should not be empty - __props__ = () - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, inline=True) diff --git a/pymc/testing.py b/pymc/testing.py index 2c3ce51db5..551356ee33 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -47,8 +47,6 @@ ParameterValueError, local_check_parameter_to_ninf_switch, ) -from pymc.model import Model -from pymc.model.fgraph import fgraph_from_model from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph # This mode can be used for tests where model compilations takes the bulk of the runtime @@ -241,7 +239,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None): if extra_args is None: extra_args = {} - with Model() as m: + with pm.Model() as m: param_vars = {} for v, dom in vardomains.items(): v_pt = pytensor.shared(np.asarray(dom.vals[0])) @@ -1211,37 +1209,3 @@ def equal_computations_up_to_root( return False return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type] - - -def assert_equivalent_models(model1: Model, model2: Model): - """Check whether two PyMC models are equivalent. - - Examples - -------- - - .. code-block:: python - - import pymc as pm - from pymc_extras.utils.model_equivalence import equivalent_models - - with pm.Model() as m1: - x = pm.Normal("x") - y = pm.Normal("y", x) - - with pm.Model() as m2: - x = pm.Normal("x") - y = pm.Normal("y", x + 1) - - with pm.Model() as m3: - x = pm.Normal("x") - y = pm.Normal("y", x) - - assert not equivalent_models(m1, m2) - assert equivalent_models(m1, m3) - - """ - fgraph1, _ = fgraph_from_model(model1) - fgraph2, _ = fgraph_from_model(model2) - - are_equivalent = equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs) - assert are_equivalent, "Models are not equivalent" From 0fbc7d9fa3383d6cee466cdb0a90aff39cb51e75 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Nov 2025 16:02:47 +0100 Subject: [PATCH 6/7] broken WIP move minibatch transform and rename/rework --- pymc/model/transform/basic.py | 111 +------------ pymc/model/transform/minibatch.py | 201 ++++++++++++++++++++++++ tests/model/transform/test_basic.py | 96 ++--------- tests/model/transform/test_minibatch.py | 92 +++++++++++ 4 files changed, 304 insertions(+), 196 deletions(-) create mode 100644 pymc/model/transform/minibatch.py create mode 100644 tests/model/transform/test_minibatch.py diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index bdc66469d4..fcf42fdf8c 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -13,11 +13,9 @@ # limitations under the License. from collections.abc import Sequence -from pytensor import Variable, clone_replace +from pytensor import Variable from pytensor.graph import ancestors -from pytensor.graph.fg import FunctionGraph -from pymc.data import Minibatch, MinibatchOp from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, @@ -25,7 +23,6 @@ fgraph_from_model, model_from_fgraph, ) -from pymc.pytensorf import toposort_replace ModelVariable = Variable | str @@ -61,109 +58,3 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l else: vars_seq = (vars,) return [model[var] if isinstance(var, str) else var for var in vars_seq] - - -def model_to_minibatch( - model: Model, *, batch_size: int, minibatch_vars: list[str] | None = None -) -> Model: - """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" - from pymc.variational.minibatch_rv import create_minibatch_rv - - if minibatch_vars is None: - original_minibatch_vars = [ - variable - for variable in model.data_vars - if (variable.type.ndim > 0) and (variable.type.shape[0] is None) - ] - else: - original_minibatch_vars = parse_vars(model, minibatch_vars) - for variable in original_minibatch_vars: - if variable.type.ndim == 0: - raise ValueError( - f"Cannot minibatch {variable.name} because it is a scalar variable." - ) - if variable.type.shape[0] is not None: - raise ValueError( - f"Cannot minibatch {variable.name} because its first dimension is static " - f"(size={variable.type.shape[0]})." - ) - - # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed - # shape, but mu from data --> y cannot be minibatched because of sigma. - - fgraph, memo = fgraph_from_model(model, inlined_views=True) - - pre_minibatch_vars = [memo[var] for var in original_minibatch_vars] - minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size) - - # Replace uses of the specified data variables with Minibatch variables - # We need a two-step clone because FunctionGraph can only mutate one variable at a time - # and when there are multiple vars to minibatch you end up replacing the same variable twice recursively - # exampre: out = x + y - # goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)] - # replace x first we get: out = Minibatch(x, y).0 + y - # then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1 - # The second replacement of y ends up creating a circular dependency - pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars) - dummy_to_minibatch_var = tuple( - (dummy, minibatch_var) - for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars) - ) - - # Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves, - # So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs - other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars] - minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False) - minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined] - minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] - toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy) - toposort_replace(minibatch_fgraph, dummy_to_minibatch_var) - - # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs - dependent_replacements = {} - total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...) - vars_to_minibatch_set = set(pre_minibatch_vars) - for model_var in minibatch_fgraph.outputs: - if not (set(ancestors([model_var])) & vars_to_minibatch_set): - continue - if not isinstance(model_var.owner.op, ModelObservedRV): - raise ValueError( - "Minibatching only supports observed RVs depending on minibatched variables. " - f"Found dependent unobserved variable: {model_var.name}." - ) - # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim - # And conversely other variables do not have that dim - observed_rv = model_var.owner.inputs[0] - dependent_replacements[observed_rv] = create_minibatch_rv( - observed_rv, total_size=total_size - ) - - toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items())) - - # Finally reintroduce the original data variable outputs - for pre_minibatch_var in pre_minibatch_vars: - minibatch_fgraph.add_output(pre_minibatch_var) - - return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True) - - -def remove_minibatched_nodes(model: Model) -> Model: - """Remove all uses of pm.Minibatch in the Model.""" - fgraph, _ = fgraph_from_model(model) - - replacements = {} - for var in fgraph.apply_nodes: - if isinstance(var.op, MinibatchOp): - for inp, out in zip(var.inputs, var.outputs): - replacements[out] = inp - - old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] - # Using `rebuild_strict=False` means all coords, names, and dim information is lost - # So we need to restore it from the old fgraph - new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] - for old_out, new_out in zip(old_outs, new_outs): - new_out.name = old_out.name - fgraph = FunctionGraph(outputs=new_outs, clone=False) - fgraph._coords = old_coords # type: ignore[attr-defined] - fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] - return model_from_fgraph(fgraph, mutate_fgraph=True) diff --git a/pymc/model/transform/minibatch.py b/pymc/model/transform/minibatch.py new file mode 100644 index 0000000000..e567d45c8a --- /dev/null +++ b/pymc/model/transform/minibatch.py @@ -0,0 +1,201 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence + +from pytensor import Variable +from pytensor.graph import FunctionGraph, ancestors + +from build.lib.pymc.variational.minibatch_rv import MinibatchRandomVariable +from pymc import Minibatch, Model +from pymc.data import MinibatchOp +from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph +from pymc.model.transform.basic import parse_vars +from pymc.pytensorf import toposort_replace + + +def minibatch_model( + model: Model, + *, + batch_size: int, + minibatch_vars: Sequence[str | Variable] | None = None, +) -> Model: + """Create a minibatch version of the given Model. + + Replaces minibatch_vars data containers with Minibatch views and rescales the logp of dependent observed variables. + + .. warning:: This transformation acts on the leading dimension of the specified data variables and dependent observed RVs. If a dimension other than the first is linked to the minibatched data variables, the resulting model will be invalid. + + Parameters + ---------- + model : Model + The original model to transform. + batch_size : int + The minibatch size to use. + minibatch_vars : Sequence of Variable or string, optional + Data variables to convert to minibatch. If None, all data variables with a leading dimension of size None will be minibatched. + + Returns + ------- + Model + A new Model with the specified data variables replaced by Minibatch views and dependent observed RVs adjusted accordingly. + + Raises + ------ + ValueError + If any of the specified variables cannot be minibatched (e.g., scalar variables or variables with static leading dimensions), or if dependent variables are Potentials / Unobserved RVs. + + Examples + -------- + .. code-block:: python + + import numpy as np + import pymc as pm + from pymc.model.transform.minibatch import minibatch_model + + with pm.Model() as m: + obs_data = pm.Data("obs_data", np.random.normal(size=(100,))) + X_data = pm.Data("X_data", np.random.normal(size=(100, 4))) + beta = pm.Normal("beta", mu=np.pi, dims="feature") + + mu = X_data @ beta + y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data) + + with minibatch_model(m, batch_size=10) as mb: + pm.fit() + """ + from pymc.variational.minibatch_rv import create_minibatch_rv + + if minibatch_vars is None: + original_minibatch_vars = [ + variable + for variable in model.data_vars + if (variable.type.ndim > 0) and (variable.type.shape[0] is None) + ] + else: + original_minibatch_vars = parse_vars(model, minibatch_vars) + for variable in original_minibatch_vars: + if variable.type.ndim == 0: + raise ValueError( + f"Cannot minibatch {variable.name} because it is a scalar variable." + ) + if variable.type.shape[0] is not None: + raise ValueError( + f"Cannot minibatch {variable.name} because its first dimension is static " + f"(size={variable.type.shape[0]})." + ) + + # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed + # shape, but mu from data --> y cannot be minibatched because of sigma. + + fgraph, memo = fgraph_from_model(model, inlined_views=True) + + pre_minibatch_vars = [memo[var] for var in original_minibatch_vars] + minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size) + + # Replace uses of the specified data variables with Minibatch variables + # We need a two-step clone because FunctionGraph can only mutate one variable at a time + # and when there are multiple vars to minibatch you end up replacing the same variable twice recursively + # exampre: out = x + y + # goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)] + # replace x first we get: out = Minibatch(x, y).0 + y + # then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1 + # The second replacement of y ends up creating a circular dependency + pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars) + dummy_to_minibatch_var = tuple( + (dummy, minibatch_var) + for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars) + ) + + # Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves, + # So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs + other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars] + minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False) + minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined] + minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] + toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy) + toposort_replace(minibatch_fgraph, dummy_to_minibatch_var) + + # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs + dependent_replacements = {} + total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...) + vars_to_minibatch_set = set(pre_minibatch_vars) + for model_var in minibatch_fgraph.outputs: + if not (set(ancestors([model_var])) & vars_to_minibatch_set): + continue + if not isinstance(model_var.owner.op, ModelObservedRV): + raise ValueError( + "Minibatching only supports observed RVs depending on minibatched variables. " + f"Found dependent unobserved variable: {model_var.name}." + ) + # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim + # And conversely other variables do not have that dim + observed_rv = model_var.owner.inputs[0] + dependent_replacements[observed_rv] = create_minibatch_rv( + observed_rv, total_size=total_size + ) + + toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items())) + + # Finally reintroduce the original data variable outputs + for pre_minibatch_var in pre_minibatch_vars: + minibatch_fgraph.add_output(pre_minibatch_var) + + return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True) + + +def remove_minibatch(model: Model) -> Model: + """Remove all uses of Minibatch data and random variables from the Model. + + Parameters + ---------- + model : Model + The original model to transform. + + Returns + ------- + Model + A new Model with all Minibatch data variables and MinibatchRVs replaced by their original counterparts. + + Examples + -------- + .. code-block:: python + + import pymc as pm + from pymc.model.transform.minibatch import undo_minibatch + + with pm.Model() as mb: + X_data = pm.Data("X_data", [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + obs_data = pm.Data("obs_data", [1, 2, 3, 4, 5]) + minibatch_X_data, minibatch_obs_data = pm.Minibatch(X_data, obs_data, batch_size=3) + + beta = pm.Normal("beta", shape=(2,)) + mu = minibatch_X_data @ beta + y = pm.Normal("y", mu=mu, sigma=1, observed=minibatch_obs_data, total_size=(5,)) + + with undo_minibatch(mb) as m: + idata = pm.sample_prior_predictive() + assert idata.prior["y"].shape[-1] == 5 # Original data size restored + + """ + fgraph, _ = fgraph_from_model(model) + + replacements = [] + for var in fgraph.apply_nodes: + if isinstance(var.op, MinibatchOp): + replacements.extend(zip(var.inputs, var.outputs)) + elif isinstance(var.op, MinibatchRandomVariable): + replacements.append((var.outputs[0], var.inputs[0])) + + toposort_replace(fgraph, replacements) + return model_from_fgraph(fgraph, mutate_fgraph=True) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 75dd6bf9bc..4042ee0da3 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -11,100 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - -import pymc as pm - +from pymc import Data, Model +from pymc.distributions import Normal from pymc.model.transform.basic import ( - model_to_minibatch, prune_vars_detached_from_observed, - remove_minibatched_nodes, ) def test_prune_vars_detached_from_observed(): - with pm.Model() as m: - obs_data = pm.Data("obs_data", 0) - a0 = pm.Data("a0", 0) - a1 = pm.Normal("a1", a0) - a2 = pm.Normal("a2", a1) - pm.Normal("obs", a2, observed=obs_data) + with Model() as m: + obs_data = Data("obs_data", 0) + a0 = Data("a0", 0) + a1 = Normal("a1", a0) + a2 = Normal("a2", a1) + Normal("obs", a2, observed=obs_data) - d0 = pm.Data("d0", 0) - d1 = pm.Normal("d1", d0) + d0 = Data("d0", 0) + d1 = Normal("d1", d0) assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} pruned_m = prune_vars_detached_from_observed(m) assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} - - -def test_model_to_minibatch(): - data_size = 100 - n_features = 4 - - obs_data_np = np.random.normal(size=(data_size,)) - X_data_np = np.random.normal(size=(data_size, n_features)) - - with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m: - obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"]) - X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"]) - beta = pm.Normal("beta", mu=np.pi, dims="feature") - - mu = X_data @ beta - y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") - - with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as ref_m: - obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"]) - X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"]) - minibatch_obs_data, minibatch_X_data = pm.Minibatch(obs_data, X_data, batch_size=10) - beta = pm.Normal("beta", mu=np.pi, dims="feature") - mu = minibatch_X_data @ beta - y = pm.Normal( - "y", - mu=mu, - sigma=1, - observed=minibatch_obs_data, - dims="data_dim", - total_size=(obs_data.shape[0], ...), - ) - - mb = model_to_minibatch(m, batch_size=10) - mb_logp_fn = mb.compile_logp(random_seed=42) - ref_mb_logp_fn = ref_m.compile_logp(random_seed=42) - ip = mb.initial_point() - - mb_res1 = mb_logp_fn(ip) - ref_mb_res1 = ref_mb_logp_fn(ip) - np.testing.assert_allclose(mb_res1, ref_mb_res1) - mb_res2 = mb_logp_fn(ip) - # Minibatch should give different results on each call - assert mb_res1 != mb_res2 - ref_mb_res2 = ref_mb_logp_fn(ip) - np.testing.assert_allclose(mb_res2, ref_mb_res2) - - m_again = remove_minibatched_nodes(mb) - m_again_logp_fn = m_again.compile_logp(random_seed=42) - m_logp_fn = m_again.compile_logp(random_seed=42) - ip = m_again.initial_point() - m_again_res = m_again_logp_fn(ip) - m_res = m_logp_fn(ip) - np.testing.assert_allclose(m_again_res, m_res) - # Check that repeated calls give the same result (no more minibatching) - np.testing.assert_allclose(m_again_res, m_again_logp_fn(ip)) - - -def test_remove_minibatches(): - data_size = 100 - data = np.zeros((data_size,)) - batch_size = 10 - with pm.Model(coords={"d": range(5)}) as m1: - mb = pm.Minibatch(data, batch_size=batch_size) - mu = pm.Normal("mu", dims="d") - x = pm.Normal("x") - y = pm.Normal("y", x, observed=mb, total_size=100) - - m2 = remove_minibatched_nodes(m1) - assert m1.y.shape[0].eval() == batch_size - assert m2.y.shape[0].eval() == data_size - assert m1.coords == m2.coords - assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval() diff --git a/tests/model/transform/test_minibatch.py b/tests/model/transform/test_minibatch.py new file mode 100644 index 0000000000..771577f1a4 --- /dev/null +++ b/tests/model/transform/test_minibatch.py @@ -0,0 +1,92 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from pymc.data import Data, Minibatch +from pymc.distributions import Normal +from pymc.model.core import Model +from pymc.model.transform.minibatch import minibatch_model, remove_minibatch + + +def test_minibatch_model(): + data_size = 100 + n_features = 4 + + obs_data_np = np.random.normal(size=(data_size,)) + X_data_np = np.random.normal(size=(data_size, n_features)) + + with Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m: + obs_data = Data("obs_data", obs_data_np, dims=["data_dim"]) + X_data = Data("X_data", X_data_np, dims=["data_dim", "feature"]) + beta = Normal("beta", mu=np.pi, dims="feature") + + mu = X_data @ beta + y = Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") + + with Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as ref_m: + obs_data = Data("obs_data", obs_data_np, dims=["data_dim"]) + X_data = Data("X_data", X_data_np, dims=["data_dim", "feature"]) + minibatch_obs_data, minibatch_X_data = Minibatch(obs_data, X_data, batch_size=10) + beta = Normal("beta", mu=np.pi, dims="feature") + mu = minibatch_X_data @ beta + y = Normal( + "y", + mu=mu, + sigma=1, + observed=minibatch_obs_data, + dims="data_dim", + total_size=(obs_data.shape[0], ...), + ) + + mb = minibatch_model(m, batch_size=10) + mb_logp_fn = mb.compile_logp(random_seed=42) + ref_mb_logp_fn = ref_m.compile_logp(random_seed=42) + ip = mb.initial_point() + + mb_res1 = mb_logp_fn(ip) + ref_mb_res1 = ref_mb_logp_fn(ip) + np.testing.assert_allclose(mb_res1, ref_mb_res1) + mb_res2 = mb_logp_fn(ip) + # Minibatch should give different results on each call + assert mb_res1 != mb_res2 + ref_mb_res2 = ref_mb_logp_fn(ip) + np.testing.assert_allclose(mb_res2, ref_mb_res2) + + # Test round-trip minibatch -> remove_minibatch + m_again = remove_minibatch(mb) + m_again_logp_fn = m_again.compile_logp(random_seed=42) + m_logp_fn = m_again.compile_logp(random_seed=42) + ip = m_again.initial_point() + m_again_res = m_again_logp_fn(ip) + m_res = m_logp_fn(ip) + np.testing.assert_allclose(m_again_res, m_res) + # Check that repeated calls give the same result (no more minibatching) + np.testing.assert_allclose(m_again_res, m_again_logp_fn(ip)) + + +def test_remove_minibatches(): + data_size = 100 + data = np.zeros((data_size,)) + batch_size = 10 + with Model(coords={"d": range(5)}) as m1: + mb = Minibatch(data, batch_size=batch_size) + mu = Normal("mu", dims="d") + x = Normal("x") + y = Normal("y", x, observed=mb, total_size=100) + + m2 = remove_minibatch(m1) + assert m1.y.shape[0].eval() == batch_size + assert m2.y.shape[0].eval() == data_size + assert m1.coords == m2.coords + assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval() From a3bc54bcfa2dc43580c5a15fe65260bcb3837467 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Nov 2025 19:05:40 +0100 Subject: [PATCH 7/7] WIP Allow rebuilding graph in toposort_replace --- pymc/model/core.py | 2 +- pymc/model/transform/minibatch.py | 38 ++++---- pymc/pytensorf.py | 111 +++++++++++++++++++++++- pymc/variational/minibatch_rv.py | 95 ++++++++++++-------- tests/model/transform/test_minibatch.py | 70 ++++++++++----- tests/test_pytensorf.py | 67 ++++++++++++++ 6 files changed, 305 insertions(+), 78 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 69e1fbed72..6f286ee9c8 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1355,7 +1355,7 @@ def make_obs_var( elif not isinstance(data, Variable): data = pt.as_tensor_variable(data, name=name) - if total_size: + if total_size is not None: from pymc.variational.minibatch_rv import create_minibatch_rv rv_var = create_minibatch_rv(rv_var, total_size) diff --git a/pymc/model/transform/minibatch.py b/pymc/model/transform/minibatch.py index e567d45c8a..364e1054c9 100644 --- a/pymc/model/transform/minibatch.py +++ b/pymc/model/transform/minibatch.py @@ -16,12 +16,12 @@ from pytensor import Variable from pytensor.graph import FunctionGraph, ancestors -from build.lib.pymc.variational.minibatch_rv import MinibatchRandomVariable from pymc import Minibatch, Model from pymc.data import MinibatchOp from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph from pymc.model.transform.basic import parse_vars from pymc.pytensorf import toposort_replace +from pymc.variational.minibatch_rv import MinibatchRandomVariable def minibatch_model( @@ -36,6 +36,8 @@ def minibatch_model( .. warning:: This transformation acts on the leading dimension of the specified data variables and dependent observed RVs. If a dimension other than the first is linked to the minibatched data variables, the resulting model will be invalid. + .. warning:: When minibatch_vars are not specified, all non-scalar data variables will be minibatch. This can be incorrect! + Parameters ---------- model : Model @@ -43,7 +45,7 @@ def minibatch_model( batch_size : int The minibatch size to use. minibatch_vars : Sequence of Variable or string, optional - Data variables to convert to minibatch. If None, all data variables with a leading dimension of size None will be minibatched. + Data variables to convert to minibatch. If None, all non scalar data variables will be minibatched. Returns ------- @@ -78,9 +80,7 @@ def minibatch_model( if minibatch_vars is None: original_minibatch_vars = [ - variable - for variable in model.data_vars - if (variable.type.ndim > 0) and (variable.type.shape[0] is None) + variable for variable in model.data_vars if variable.type.ndim > 0 ] else: original_minibatch_vars = parse_vars(model, minibatch_vars) @@ -89,11 +89,6 @@ def minibatch_model( raise ValueError( f"Cannot minibatch {variable.name} because it is a scalar variable." ) - if variable.type.shape[0] is not None: - raise ValueError( - f"Cannot minibatch {variable.name} because its first dimension is static " - f"(size={variable.type.shape[0]})." - ) # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed # shape, but mu from data --> y cannot be minibatched because of sigma. @@ -124,10 +119,10 @@ def minibatch_model( minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined] minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy) - toposort_replace(minibatch_fgraph, dummy_to_minibatch_var) + toposort_replace(minibatch_fgraph, dummy_to_minibatch_var, rebuild=True) # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs - dependent_replacements = {} + dependent_replacements = [] total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...) vars_to_minibatch_set = set(pre_minibatch_vars) for model_var in minibatch_fgraph.outputs: @@ -141,11 +136,10 @@ def minibatch_model( # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim # And conversely other variables do not have that dim observed_rv = model_var.owner.inputs[0] - dependent_replacements[observed_rv] = create_minibatch_rv( - observed_rv, total_size=total_size - ) + minibatch_rv = create_minibatch_rv(observed_rv, total_size=total_size) + dependent_replacements.append((observed_rv, minibatch_rv)) - toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items())) + toposort_replace(minibatch_fgraph, dependent_replacements, rebuild=True) # Finally reintroduce the original data variable outputs for pre_minibatch_var in pre_minibatch_vars: @@ -191,11 +185,11 @@ def remove_minibatch(model: Model) -> Model: fgraph, _ = fgraph_from_model(model) replacements = [] - for var in fgraph.apply_nodes: - if isinstance(var.op, MinibatchOp): - replacements.extend(zip(var.inputs, var.outputs)) - elif isinstance(var.op, MinibatchRandomVariable): - replacements.append((var.outputs[0], var.inputs[0])) + for node in fgraph.apply_nodes: + if isinstance(node.op, MinibatchOp): + replacements.extend(zip(node.outputs[:-1], node.inputs[:-1])) + elif isinstance(node.op, MinibatchRandomVariable): + replacements.append((node.outputs[0], node.inputs[0])) - toposort_replace(fgraph, replacements) + toposort_replace(fgraph, replacements, rebuild=True) return model_from_fgraph(fgraph, mutate_fgraph=True) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index d7e097f6dc..e320a93bf5 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -13,7 +13,9 @@ # limitations under the License. import warnings +from collections import deque from collections.abc import Iterable, Sequence +from itertools import chain from typing import cast import numpy as np @@ -1032,12 +1034,105 @@ def as_symbolic_string(x, **kwargs): return StringConstant(stringtype, x) +def _replace_rebuild( + fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], **kwargs +) -> FunctionGraph: + """Replace variables and rebuild dependent graph if needed. + + Rebuilding allows for replacements that change the semantics of the graph + (different types), which may not be possible for all Ops. + """ + fg_clients = fgraph.clients + fg_variables = fgraph.variables + + def get_client_nodes(vars) -> set[Apply]: + # Start with the immediate clients of vars + nodes = set() + d = list(chain.from_iterable(fg_clients[var] for var in vars if var in fg_variables)) + while d: + node, _ = d.pop() + if node in nodes or isinstance(node.op, Output): + continue + nodes.add(node) + # Keep walking to the successor clients + d.extend(chain.from_iterable(fg_clients[out] for out in node.outputs)) + return nodes + + repl_dict = dict(replacements) + root_nodes = {var.owner for var in repl_dict.keys()} + + # Build sorted queue with all nodes that depend on replaced variables + topo_order = {node: order for order, node in enumerate(fgraph.toposort())} + d = deque( + sorted( + get_client_nodes(repl_dict.keys()), + key=lambda node: topo_order[node], + ) + ) + while d: + node = d.popleft() + if node in root_nodes: + continue + + new_inputs = [repl_dict.get(i, i) for i in node.inputs] + if new_inputs == node.inputs: + continue + + # We need to remake the node if: + # 1. The output type depends on an input value + # 2. Any of the input type changed + if getattr(node.op, "_output_type_depends_on_input_value", False): + remake_node = True + else: + remake_node = any( + inp.type != new_inp.type for inp, new_inp in zip(node.inputs, new_inputs) + ) + + if remake_node: + new_node = node.clone_with_new_inputs(new_inputs, strict=False) + fgraph.import_node(new_node, import_missing=True) + + # We are not always allowed to call `fgraph.replace_all` because the output types may be incompatible + # We will keep the changes in repl_dict until we can replace a node without remaking it, + # or we arrive to the end of the graph, in which case we need to replace the FunctionGraph output + for out, new_out in zip(node.outputs, new_node.outputs): + new_out.name = out.name + repl_dict[out] = new_out + else: + fgraph.replace_all(tuple(zip(node.inputs, new_inputs)), import_missing=True) + + # If the FunctionGraph outputs themselves were rebuilt we need to handle them + for i, (new_output, old_output) in enumerate( + zip( + (repl_dict.get(out, out) for out in fgraph.outputs), + fgraph.outputs, + ) + ): + if new_output is old_output: + continue + fgraph.outputs[i] = new_output + fgraph.import_var(new_output, import_missing=True) + fgraph.clients[new_output] = [ + # Output variables have a special Output Op client + # We need to transfer it to the new output. + # Any other uses of this output variable will already have been substituted in the loop above, + # or are part of other outputs we will subsitute next + (cl.op.make_node(new_output), idx) if isinstance(cl.op, Output) else (cl, idx) + for cl, idx in fgraph.clients[old_output] + ] + return fgraph + + def toposort_replace( fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False, + rebuild: bool = False, ) -> None: """Replace multiple variables in place in topological order.""" + if rebuild and reverse: + raise NotImplementedError("reverse rebuild not yet supported") + fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())} fgraph_toposort[None] = -1 # Variables without owner are not in the toposort sorted_replacements = sorted( @@ -1045,7 +1140,21 @@ def toposort_replace( key=lambda pair: fgraph_toposort[pair[0].owner], reverse=reverse, ) - fgraph.replace_all(sorted_replacements, import_missing=True) + + if rebuild: + if len(replacements) > 1: + # In this case we need to modify the replacements recursively with each other + sorted_replacements = [list(pairs) for pairs in sorted_replacements] + for i in range(1, len(replacements)): + temp_fgraph = FunctionGraph( + outputs=[repl for _, repl in sorted_replacements[i:]], + clone=False, + ) + _replace_rebuild(temp_fgraph, replacements=sorted_replacements[:i]) + sorted_replacements[i][1] = temp_fgraph.outputs[0] + _replace_rebuild(fgraph, sorted_replacements) + else: + fgraph.replace_all(sorted_replacements, import_missing=True) def normalize_rng_param(rng: None | Variable) -> Variable: diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py index 34d30cfa50..133165b279 100644 --- a/pymc/variational/minibatch_rv.py +++ b/pymc/variational/minibatch_rv.py @@ -19,6 +19,7 @@ from pytensor import Variable, config from pytensor.graph import Apply, Op from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable +from pytensor.tensor.type_other import NoneTypeT from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import logp @@ -33,7 +34,9 @@ class MinibatchRandomVariable(MeasurableOp, Op): def make_node(self, rv, *total_size): rv = as_tensor_variable(rv) total_size = [ - as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst + t + if isinstance(t, Variable) + else (NoneConst if t is None else as_tensor_variable(t, dtype="int64", ndim=0)) for t in total_size ] assert len(total_size) == rv.ndim @@ -55,45 +58,67 @@ def perform(self, node, inputs, output_storage): def create_minibatch_rv( rv: TensorVariable, - total_size: int | None | Sequence[int | EllipsisType | None], + total_size: int | TensorVariable | Sequence[int | TensorVariable | EllipsisType | None], ) -> TensorVariable: """Create variable whose logp is rescaled by total_size.""" + rv_ndim_supp = rv.owner.op.ndim_supp + if isinstance(total_size, int): - if rv.ndim <= 1: - total_size = [total_size] + total_size = (total_size, *([None] * rv_ndim_supp)) + elif isinstance(total_size, TensorVariable): + if total_size.type.ndim == 0: + total_size = (total_size, *([None] * rv_ndim_supp)) + elif total_size.type.ndim == 1: + total_size = tuple(total_size) else: - missing_ndims = rv.ndim - 1 - total_size = [total_size] + [None] * missing_ndims - elif isinstance(total_size, list | tuple): - total_size = list(total_size) - if Ellipsis in total_size: - # Replace Ellipsis by None - if total_size.count(Ellipsis) > 1: - raise ValueError("Only one Ellipsis can be present in total_size") - sep = total_size.index(Ellipsis) - begin = total_size[:sep] - end = total_size[sep + 1 :] - missing_ndims = max((rv.ndim - len(begin) - len(end), 0)) - total_size = begin + [None] * missing_ndims + end - if len(total_size) > rv.ndim: - raise ValueError(f"Length of total_size {total_size} is langer than RV ndim {rv.ndim}") - else: - raise TypeError(f"Invalid type for total_size: {total_size}") - - return cast(TensorVariable, minibatch_rv(rv, *total_size)) - - -def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable: + raise ValueError( + f"Total size must be a 0d or 1d vector got {total_size} with {total_size.type.ndim} dimensions" + ) + + if not isinstance(total_size, list | tuple): + raise ValueError(f"Invalid type for total_size {total_size}: {type(total_size)}") + + if Ellipsis in total_size: + # Replace Ellipsis by None + if total_size.count(Ellipsis) > 1: + raise ValueError("Only one Ellipsis can be present in total_size") + sep = total_size.index(Ellipsis) + begin = total_size[:sep] + end = total_size[sep + 1 :] + missing_ndims = max((rv_ndim_supp - len(begin) - len(end), 0)) + total_size = (*begin, *([None] * missing_ndims), *end) + + if (len(total_size) - rv_ndim_supp) not in (0, 1): + raise ValueError( + f"Length of total_size {total_size} not compatble with ndim_supp of RV {rv}, " + f"got {len(total_size)} but must be {rv_ndim_supp} or {rv_ndim_supp - 1}" + ) + + out = minibatch_rv(rv, *total_size) + assert isinstance(out.owner.op, MinibatchRandomVariable) + return cast(TensorVariable, out) + + +def get_scaling( + total_size: Sequence[TensorVariable], shape: TensorVariable | Sequence[TensorVariable] +) -> TensorVariable: """Get scaling constant for logp.""" # mypy doesn't understand we can convert a shape TensorVariable into a tuple - shape = tuple(shape) # type: ignore[assignment] + shape = tuple(shape) + + if len(total_size) == (len(shape) - 1): + # This happens when RV has no batch dimensions + # In that case the total_size corresponds to a dummy shape of 1 + total_size = (1, *total_size) + + assert len(shape) == len(total_size) - # Scalar RV - if len(shape) == 0: # type: ignore[arg-type] - coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0 - else: - coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)] - coef = pt.prod(coefs) + coefs = [ + size / dim_length + for size, dim_length in zip(total_size, shape) + if not isinstance(size.type, NoneTypeT) + ] + coef = pt.prod(coefs) if len(coefs) > 1 else coefs[0] return pt.cast(coef, dtype=config.floatX) @@ -102,4 +127,6 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor def minibatch_rv_logprob(op, values, *inputs, **kwargs): [value] = values rv, *total_size = inputs - return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape) + raw_logp = logp(rv, value, **kwargs) + scaled_logp = raw_logp * get_scaling(total_size, raw_logp.shape) + return scaled_logp diff --git a/tests/model/transform/test_minibatch.py b/tests/model/transform/test_minibatch.py index 771577f1a4..20b9c0a2a7 100644 --- a/tests/model/transform/test_minibatch.py +++ b/tests/model/transform/test_minibatch.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np +import pytest from pymc.data import Data, Minibatch -from pymc.distributions import Normal +from pymc.distributions import HalfNormal, Normal from pymc.model.core import Model from pymc.model.transform.minibatch import minibatch_model, remove_minibatch +from pymc.variational.minibatch_rv import MinibatchRandomVariable def test_minibatch_model(): @@ -63,8 +65,53 @@ def test_minibatch_model(): ref_mb_res2 = ref_mb_logp_fn(ip) np.testing.assert_allclose(mb_res2, ref_mb_res2) - # Test round-trip minibatch -> remove_minibatch - m_again = remove_minibatch(mb) + +def test_remove_minibatch(): + data_size = 100 + n_features = 5 + batch_size = 10 + with Model(coords={"d": range(n_features)}) as mb: + X_data = Data("X_data", np.random.normal(size=(data_size, n_features))) + obs_data = Data("obs_data", [1, 2, 3, 4, 5]) + minibatch_X_data, minibatch_obs_data = Minibatch(X_data, obs_data, batch_size=batch_size) + + beta = Normal("beta", dims=("d",)) + mu = minibatch_X_data @ beta + sigma = HalfNormal("sigma") + y = Normal("y", mu=mu, sigma=sigma, observed=minibatch_obs_data, total_size=X_data.shape[0]) + + m = remove_minibatch(mb) + assert isinstance(mb.y.owner.op, MinibatchRandomVariable) + assert tuple(mb.y.shape).eval() == (batch_size,) + assert isinstance(m.y.owner.op, Normal) + assert tuple(m.y.shape.eval()) == (data_size,) + assert mb.coords == m.coords + assert mb.dim_lengths["d"].eval() == m.dim_lengths["d"].eval() + + +@pytest.mark.parametrize("static_shape", (True, False)) +def test_minibatch_transform_roundtrip(static_shape): + data_size = 100 + n_features = 4 + with Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m: + obs_data = Data( + "obs_data", + np.random.normal(size=(data_size,)), + dims=["data_dim"], + shape=(data_size if static_shape else None,), + ) + X_data = Data( + "X_data", + np.random.normal(size=(data_size, n_features)), + dims=["data_dim", "feature"], + shape=(data_size if static_shape else None, n_features), + ) + beta = Normal("beta", mu=np.pi, dims="feature") + + mu = X_data @ beta + y = Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") + + m_again = remove_minibatch(minibatch_model(m, batch_size=10)) m_again_logp_fn = m_again.compile_logp(random_seed=42) m_logp_fn = m_again.compile_logp(random_seed=42) ip = m_again.initial_point() @@ -73,20 +120,3 @@ def test_minibatch_model(): np.testing.assert_allclose(m_again_res, m_res) # Check that repeated calls give the same result (no more minibatching) np.testing.assert_allclose(m_again_res, m_again_logp_fn(ip)) - - -def test_remove_minibatches(): - data_size = 100 - data = np.zeros((data_size,)) - batch_size = 10 - with Model(coords={"d": range(5)}) as m1: - mb = Minibatch(data, batch_size=batch_size) - mu = Normal("mu", dims="d") - x = Normal("x") - y = Normal("y", x, observed=mb, total_size=100) - - m2 = remove_minibatch(m1) - assert m1.y.shape[0].eval() == batch_size - assert m2.y.shape[0].eval() == data_size - assert m1.coords == m2.coords - assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval() diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index d172c61a4d..5d4e7d2d3d 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -24,6 +24,7 @@ from pytensor import scan, shared from pytensor.compile import UnusedInputError from pytensor.compile.builders import OpFromGraph +from pytensor.graph import FunctionGraph from pytensor.graph.basic import Variable, equal_computations from pytensor.tensor.subtensor import AdvancedIncSubtensor @@ -46,6 +47,7 @@ replace_rng_nodes, replace_vars_in_graphs, reseed_rngs, + toposort_replace, ) from pymc.vartypes import int_types @@ -785,3 +787,68 @@ def test_pickle_point_func(): np.testing.assert_allclose( point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]}) ) + + +class TestToposortReplace: + @pytest.mark.parametrize("compatible_type", (True, False)) + @pytest.mark.parametrize("num_replacements", (1, 2)) + @pytest.mark.parametrize("rebuild", (True, False)) + def test_horizontal_dependency(self, compatible_type, num_replacements, rebuild): + x = pt.vector("x", shape=(5,)) + y = pt.vector("y", shape=(5,)) + + out1 = pt.exp(x + y) + pt.log(x + y) + out2 = pt.cos(out1) + + new_shape = (5,) if compatible_type else (10,) + new_x = pt.vector("new_x", shape=new_shape) + new_y = pt.vector("new_y", shape=new_shape) + if num_replacements == 1: + replacements = [(y, new_y)] + else: + replacements = [(x, new_x), (y, new_y)] + + fg = FunctionGraph([x, y], [out1, out2], clone=False) + + # If types are incompatible, and we don't rebuild or only replace one of the variables, + # The function should fail + if not compatible_type and (not rebuild or num_replacements == 1): + with pytest.raises((TypeError, ValueError)): + toposort_replace(fg, replacements, rebuild=rebuild) + return + toposort_replace(fg, replacements, rebuild=rebuild) + + if num_replacements == 1: + expected_out1 = pt.exp(x + new_y) + pt.log(x + new_y) + else: + expected_out1 = pt.exp(new_x + new_y) + pt.log(new_x + new_y) + expected_out2 = pt.cos(expected_out1) + assert equal_computations(fg.outputs, [expected_out1, expected_out2]) + + @pytest.mark.parametrize("compatible_type", (True, False)) + @pytest.mark.parametrize("num_replacements", (2, 3)) + @pytest.mark.parametrize("rebuild", (True, False)) + def test_vertical_dependency(self, compatible_type, num_replacements, rebuild): + x = pt.vector("x", shape=(5,)) + a1 = pt.exp(x) + a2 = pt.log(a1) + out = a1 + a2 + + new_x = pt.vector("new_x", shape=(5 if compatible_type else 10,)) + if num_replacements == 2: + replacements = [(x, new_x), (a1, pt.cos(a1)), (a2, pt.sin(a2 + 5))] + else: + replacements = [(a1, pt.cos(pt.exp(new_x))), (a2, pt.sin(a2 + 5))] + + fg = FunctionGraph([x], [out], clone=False) + + if not compatible_type and not rebuild: + with pytest.raises(TypeError): + toposort_replace(fg, replacements, rebuild=rebuild) + return + toposort_replace(fg, replacements, rebuild=rebuild) + + expected_a1 = pt.cos(pt.exp(new_x)) + expected_a2 = pt.sin(pt.log(expected_a1) + 5) + expected_out = expected_a1 + expected_a2 + assert equal_computations(fg.outputs, [expected_out])