2 changes: 1 addition & 1 deletion pymc/model/core.py
@@ -1355,7 +1355,7 @@ def make_obs_var(
elif not isinstance(data, Variable):
data = pt.as_tensor_variable(data, name=name)

if total_size:
if total_size is not None:
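# Note: explicit None check, so any provided total_size (even a falsy value) creates a MinibatchRV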
from pymc.variational.minibatch_rv import create_minibatch_rv

rv_var = create_minibatch_rv(rv_var, total_size)
26 changes: 1 addition & 25 deletions pymc/model/transform/basic.py
@@ -13,11 +13,9 @@
# limitations under the License.
from collections.abc import Sequence

from pytensor import Variable, clone_replace
from pytensor import Variable
from pytensor.graph import ancestors
from pytensor.graph.fg import FunctionGraph

from pymc.data import MinibatchOp
from pymc.model.core import Model
from pymc.model.fgraph import (
ModelObservedRV,
@@ -60,25 +58,3 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
else:
vars_seq = (vars,)
return [model[var] if isinstance(var, str) else var for var in vars_seq]


def remove_minibatched_nodes(model: Model) -> Model:
"""Remove all uses of pm.Minibatch in the Model."""
fgraph, _ = fgraph_from_model(model)

replacements = {}
for var in fgraph.apply_nodes:
if isinstance(var.op, MinibatchOp):
for inp, out in zip(var.inputs, var.outputs):
replacements[out] = inp

old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined]
# Using `rebuild_strict=False` means all coords, names, and dim information is lost
# So we need to restore it from the old fgraph
new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type]
for old_out, new_out in zip(old_outs, new_outs):
new_out.name = old_out.name
fgraph = FunctionGraph(outputs=new_outs, clone=False)
fgraph._coords = old_coords # type: ignore[attr-defined]
fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined]
return model_from_fgraph(fgraph, mutate_fgraph=True)
195 changes: 195 additions & 0 deletions pymc/model/transform/minibatch.py
@@ -0,0 +1,195 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence

from pytensor import Variable
from pytensor.graph import FunctionGraph, ancestors

from pymc import Minibatch, Model
from pymc.data import MinibatchOp
from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph
from pymc.model.transform.basic import parse_vars
from pymc.pytensorf import toposort_replace
from pymc.variational.minibatch_rv import MinibatchRandomVariable


def minibatch_model(
model: Model,
*,
batch_size: int,
minibatch_vars: Sequence[str | Variable] | None = None,
) -> Model:
"""Create a minibatch version of the given Model.

Replaces minibatch_vars data containers with Minibatch views and rescales the logp of dependent observed variables.

.. warning:: This transformation acts on the leading dimension of the specified data variables and dependent observed RVs. If a dimension other than the first is linked to the minibatched data variables, the resulting model will be invalid.

.. warning:: When minibatch_vars is not specified, all non-scalar data variables will be minibatched. This can be incorrect!

Parameters
----------
model : Model
The original model to transform.
batch_size : int
The minibatch size to use.
minibatch_vars : Sequence of Variable or string, optional
Data variables to convert to minibatch. If None, all non-scalar data variables will be minibatched.

Returns
-------
Model
A new Model with the specified data variables replaced by Minibatch views and dependent observed RVs adjusted accordingly.

Raises
------
ValueError
If any of the specified variables cannot be minibatched (e.g., scalar variables or variables with static leading dimensions), or if dependent variables are Potentials / Unobserved RVs.

Examples
--------
.. code-block:: python

import numpy as np
import pymc as pm
from pymc.model.transform.minibatch import minibatch_model

with pm.Model(coords={"feature": range(4)}) as m:
obs_data = pm.Data("obs_data", np.random.normal(size=(100,)))
X_data = pm.Data("X_data", np.random.normal(size=(100, 4)))
beta = pm.Normal("beta", mu=np.pi, dims="feature")

mu = X_data @ beta
y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data)

with minibatch_model(m, batch_size=10) as mb:
pm.fit()
"""
from pymc.variational.minibatch_rv import create_minibatch_rv

if minibatch_vars is None:
original_minibatch_vars = [
variable for variable in model.data_vars if variable.type.ndim > 0
]
else:
original_minibatch_vars = parse_vars(model, minibatch_vars)
for variable in original_minibatch_vars:
if variable.type.ndim == 0:
raise ValueError(
f"Cannot minibatch {variable.name} because it is a scalar variable."
)

# TODO: Validate that the graph is actually valid to minibatch. Example: linear regression where sigma has a fixed
# shape but mu comes from data --> y cannot be minibatched because of sigma.

fgraph, memo = fgraph_from_model(model, inlined_views=True)

pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size)

# Replace uses of the specified data variables with Minibatch variables
# We need a two-step clone because FunctionGraph can only mutate one variable at a time,
# and with multiple vars to minibatch we would end up replacing the same variable twice recursively.
# Example: out = x + y
# Goal: replace (x, y) with (Minibatch(x, y).0, Minibatch(x, y).1)
# Replacing x first gives: out = Minibatch(x, y).0 + y
# Then replacing y gives: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
# The second replacement of y ends up creating a circular dependency
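# To avoid this, we first swap each data variable for a type-matched dummy, then swap the dummies for the Minibatch views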
pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
dummy_to_minibatch_var = tuple(
(dummy, minibatch_var)
for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars)
)

# Furthermore, we only want to replace uses of the data variables (x, y), not the data variables themselves,
# so we use an intermediate FunctionGraph that doesn't include the data variables as outputs
other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined]
minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined]
toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
toposort_replace(minibatch_fgraph, dummy_to_minibatch_var, rebuild=True)

# Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
dependent_replacements = []
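# total_size is the length of the original (pre-minibatch) leading dimension; the Ellipsis leaves remaining dims unscaled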
total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
vars_to_minibatch_set = set(pre_minibatch_vars)
for model_var in minibatch_fgraph.outputs:
if not (set(ancestors([model_var])) & vars_to_minibatch_set):
continue
if not isinstance(model_var.owner.op, ModelObservedRV):
raise ValueError(
"Minibatching only supports observed RVs depending on minibatched variables. "
f"Found dependent unobserved variable: {model_var.name}."
)
# TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also have that same dim,
# and conversely that other variables do not have that dim
observed_rv = model_var.owner.inputs[0]
minibatch_rv = create_minibatch_rv(observed_rv, total_size=total_size)
dependent_replacements.append((observed_rv, minibatch_rv))

toposort_replace(minibatch_fgraph, dependent_replacements, rebuild=True)

# Finally reintroduce the original data variable outputs
for pre_minibatch_var in pre_minibatch_vars:
minibatch_fgraph.add_output(pre_minibatch_var)

return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)


def remove_minibatch(model: Model) -> Model:
"""Remove all uses of Minibatch data and random variables from the Model.

Parameters
----------
model : Model
The original model to transform.

Returns
-------
Model
A new Model with all Minibatch data variables and MinibatchRVs replaced by their original counterparts.

Examples
--------
.. code-block:: python

import pymc as pm
from pymc.model.transform.minibatch import remove_minibatch

with pm.Model() as mb:
X_data = pm.Data("X_data", [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
obs_data = pm.Data("obs_data", [1, 2, 3, 4, 5])
minibatch_X_data, minibatch_obs_data = pm.Minibatch(X_data, obs_data, batch_size=3)

beta = pm.Normal("beta", shape=(2,))
mu = minibatch_X_data @ beta
y = pm.Normal("y", mu=mu, sigma=1, observed=minibatch_obs_data, total_size=(5,))

with remove_minibatch(mb) as m:
idata = pm.sample_prior_predictive()
assert idata.prior["y"].shape[-1] == 5 # Original data size restored

"""
fgraph, _ = fgraph_from_model(model)

replacements = []
for node in fgraph.apply_nodes:
if isinstance(node.op, MinibatchOp):
replacements.extend(zip(node.outputs[:-1], node.inputs[:-1]))
elif isinstance(node.op, MinibatchRandomVariable):
replacements.append((node.outputs[0], node.inputs[0]))

toposort_replace(fgraph, replacements, rebuild=True)
return model_from_fgraph(fgraph, mutate_fgraph=True)
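A minimal end-to-end sketch of the two transforms working together (variable names are illustrative, following the docstring examples above):

.. code-block:: python

    import numpy as np
    import pymc as pm

    from pymc.model.transform.minibatch import minibatch_model, remove_minibatch

    with pm.Model() as m:
        X = pm.Data("X", np.random.normal(size=(100, 4)))
        y_obs = pm.Data("y_obs", np.random.normal(size=(100,)))
        beta = pm.Normal("beta", shape=(4,))
        y = pm.Normal("y", mu=X @ beta, sigma=1, observed=y_obs)

    # Both data variables share the leading dimension, so both are minibatched together
    mb = minibatch_model(m, batch_size=10)

    # The transform is reversible: full-size data and a plain observed RV are restored
    m_roundtrip = remove_minibatch(mb)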
111 changes: 110 additions & 1 deletion pymc/pytensorf.py
@@ -13,7 +13,9 @@
# limitations under the License.
import warnings

from collections import deque
from collections.abc import Iterable, Sequence
from itertools import chain
from typing import cast

import numpy as np
@@ -1032,20 +1034,127 @@ def as_symbolic_string(x, **kwargs):
return StringConstant(stringtype, x)


def _replace_rebuild(
fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], **kwargs
) -> FunctionGraph:
"""Replace variables and rebuild dependent graph if needed.

Rebuilding allows for replacements that change the semantics of the graph
(different types), which may not be possible for all Ops.
"""
fg_clients = fgraph.clients
fg_variables = fgraph.variables

def get_client_nodes(vars) -> set[Apply]:
# Start with the immediate clients of vars
nodes = set()
d = list(chain.from_iterable(fg_clients[var] for var in vars if var in fg_variables))
while d:
node, _ = d.pop()
if node in nodes or isinstance(node.op, Output):
continue
nodes.add(node)
# Keep walking to the successor clients
d.extend(chain.from_iterable(fg_clients[out] for out in node.outputs))
return nodes

repl_dict = dict(replacements)
root_nodes = {var.owner for var in repl_dict.keys()}
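# Owners of the replaced variables themselves must not be rewired; only their client nodes are visited below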

# Build sorted queue with all nodes that depend on replaced variables
topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
d = deque(
sorted(
get_client_nodes(repl_dict.keys()),
key=lambda node: topo_order[node],
)
)
while d:
node = d.popleft()
if node in root_nodes:
continue

new_inputs = [repl_dict.get(i, i) for i in node.inputs]
if new_inputs == node.inputs:
continue

# We need to remake the node if:
# 1. The output type depends on an input value
# 2. Any of the input type changed
if getattr(node.op, "_output_type_depends_on_input_value", False):
remake_node = True
else:
remake_node = any(
inp.type != new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
)

if remake_node:
new_node = node.clone_with_new_inputs(new_inputs, strict=False)
fgraph.import_node(new_node, import_missing=True)

# We are not always allowed to call `fgraph.replace_all` because the output types may be incompatible.
# We keep the changes in repl_dict until we can replace a node without remaking it,
# or we arrive at the end of the graph, in which case we need to replace the FunctionGraph outputs themselves
for out, new_out in zip(node.outputs, new_node.outputs):
new_out.name = out.name
repl_dict[out] = new_out
else:
fgraph.replace_all(tuple(zip(node.inputs, new_inputs)), import_missing=True)

# If the FunctionGraph outputs themselves were rebuilt we need to handle them
for i, (new_output, old_output) in enumerate(
zip(
(repl_dict.get(out, out) for out in fgraph.outputs),
fgraph.outputs,
)
):
if new_output is old_output:
continue
fgraph.outputs[i] = new_output
fgraph.import_var(new_output, import_missing=True)
fgraph.clients[new_output] = [
# Output variables have a special Output Op client
# We need to transfer it to the new output.
# Any other uses of this output variable will already have been substituted in the loop above,
# or are part of other outputs we will substitute next
(cl.op.make_node(new_output), idx) if isinstance(cl.op, Output) else (cl, idx)
for cl, idx in fgraph.clients[old_output]
]
return fgraph


def toposort_replace(
fgraph: FunctionGraph,
replacements: Sequence[tuple[Variable, Variable]],
reverse: bool = False,
rebuild: bool = False,
) -> None:
"""Replace multiple variables in place in topological order."""
if rebuild and reverse:
raise NotImplementedError("reverse rebuild not yet supported")

fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
fgraph_toposort[None] = -1 # Variables without owner are not in the toposort
sorted_replacements = sorted(
replacements,
key=lambda pair: fgraph_toposort[pair[0].owner],
reverse=reverse,
)
fgraph.replace_all(sorted_replacements, import_missing=True)

if rebuild:
if len(replacements) > 1:
# With multiple replacements, later targets may reference earlier replaced variables,
# so we must rewrite the replacements through each other first
sorted_replacements = [list(pairs) for pairs in sorted_replacements]
for i in range(1, len(replacements)):
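# Rewire the i-th target through the previous i replacements on a scratch FunctionGraph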
temp_fgraph = FunctionGraph(
outputs=[repl for _, repl in sorted_replacements[i:]],
clone=False,
)
_replace_rebuild(temp_fgraph, replacements=sorted_replacements[:i])
sorted_replacements[i][1] = temp_fgraph.outputs[0]
_replace_rebuild(fgraph, sorted_replacements)
else:
fgraph.replace_all(sorted_replacements, import_missing=True)


def normalize_rng_param(rng: None | Variable) -> Variable:
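A usage sketch of the new rebuild path (`toposort_replace` is an internal helper; the shapes here are illustrative):

.. code-block:: python

    import pytensor.tensor as pt

    from pytensor.graph import FunctionGraph
    from pymc.pytensorf import toposort_replace

    x = pt.vector("x", shape=(5,))
    out = (x * 2).sum()
    fgraph = FunctionGraph(outputs=[out], clone=False)

    # Replacing x with a vector of a different static shape changes downstream types,
    # so the dependent nodes must be rebuilt rather than replaced in place
    new_x = pt.vector("new_x", shape=(10,))
    toposort_replace(fgraph, [(x, new_x)], rebuild=True)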