
Commit 2fc116c

WIP Allow rebuilding graph in toposort_replace
1 parent 0fbc7d9 commit 2fc116c

File tree

5 files changed: +238, -78 lines changed


pymc/model/core.py

Lines changed: 1 addition & 1 deletion
@@ -1355,7 +1355,7 @@ def make_obs_var(
         elif not isinstance(data, Variable):
             data = pt.as_tensor_variable(data, name=name)

-        if total_size:
+        if total_size is not None:
             from pymc.variational.minibatch_rv import create_minibatch_rv

             rv_var = create_minibatch_rv(rv_var, total_size)
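
A note on the one-line change above (my reading of the intent, not stated in the commit message): the rest of this diff lets total_size be a symbolic TensorVariable, and a symbolic value cannot safely go through a bare truthiness test, whereas an explicit None check only skips the "not provided" case. A minimal plain-Python sketch of the difference:

def should_rescale(total_size) -> bool:
    # Old guard `if total_size:` also rejects falsy-but-provided values and
    # needs bool(total_size), which symbolic variables typically refuse.
    # The new guard only skips the explicit "not provided" case.
    return total_size is not None

assert should_rescale(1000) is True
assert should_rescale((1000, None)) is True
assert should_rescale(None) is False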

pymc/model/transform/minibatch.py

Lines changed: 16 additions & 22 deletions
@@ -16,12 +16,12 @@
 from pytensor import Variable
 from pytensor.graph import FunctionGraph, ancestors

-from build.lib.pymc.variational.minibatch_rv import MinibatchRandomVariable
 from pymc import Minibatch, Model
 from pymc.data import MinibatchOp
 from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph
 from pymc.model.transform.basic import parse_vars
 from pymc.pytensorf import toposort_replace
+from pymc.variational.minibatch_rv import MinibatchRandomVariable


 def minibatch_model(
@@ -36,14 +36,16 @@ def minibatch_model(

     .. warning:: This transformation acts on the leading dimension of the specified data variables and dependent observed RVs. If a dimension other than the first is linked to the minibatched data variables, the resulting model will be invalid.

+    .. warning:: When minibatch_vars are not specified, all non-scalar data variables will be minibatched. This can be incorrect!
+
     Parameters
     ----------
     model : Model
         The original model to transform.
     batch_size : int
         The minibatch size to use.
     minibatch_vars : Sequence of Variable or string, optional
-        Data variables to convert to minibatch. If None, all data variables with a leading dimension of size None will be minibatched.
+        Data variables to convert to minibatch. If None, all non-scalar data variables will be minibatched.

     Returns
     -------
@@ -78,9 +80,7 @@ def minibatch_model(

     if minibatch_vars is None:
         original_minibatch_vars = [
-            variable
-            for variable in model.data_vars
-            if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
+            variable for variable in model.data_vars if variable.type.ndim > 0
         ]
     else:
         original_minibatch_vars = parse_vars(model, minibatch_vars)
@@ -89,11 +89,6 @@ def minibatch_model(
             raise ValueError(
                 f"Cannot minibatch {variable.name} because it is a scalar variable."
             )
-        if variable.type.shape[0] is not None:
-            raise ValueError(
-                f"Cannot minibatch {variable.name} because its first dimension is static "
-                f"(size={variable.type.shape[0]})."
-            )

     # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed
     # shape, but mu from data --> y cannot be minibatched because of sigma.
@@ -124,10 +119,10 @@
     minibatch_fgraph._coords = fgraph._coords  # type: ignore[attr-defined]
     minibatch_fgraph._dim_lengths = fgraph._dim_lengths  # type: ignore[attr-defined]
     toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
-    toposort_replace(minibatch_fgraph, dummy_to_minibatch_var)
+    toposort_replace(minibatch_fgraph, dummy_to_minibatch_var, rebuild=True)

     # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
-    dependent_replacements = {}
+    dependent_replacements = []
     total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
     vars_to_minibatch_set = set(pre_minibatch_vars)
     for model_var in minibatch_fgraph.outputs:
@@ -141,11 +136,10 @@ def minibatch_model(
         # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim
         # And conversely other variables do not have that dim
         observed_rv = model_var.owner.inputs[0]
-        dependent_replacements[observed_rv] = create_minibatch_rv(
-            observed_rv, total_size=total_size
-        )
+        minibatch_rv = create_minibatch_rv(observed_rv, total_size=total_size)
+        dependent_replacements.append((observed_rv, minibatch_rv))

-    toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items()))
+    toposort_replace(minibatch_fgraph, dependent_replacements, rebuild=True)

     # Finally reintroduce the original data variable outputs
     for pre_minibatch_var in pre_minibatch_vars:
@@ -191,11 +185,11 @@ def remove_minibatch(model: Model) -> Model:
     fgraph, _ = fgraph_from_model(model)

     replacements = []
-    for var in fgraph.apply_nodes:
-        if isinstance(var.op, MinibatchOp):
-            replacements.extend(zip(var.inputs, var.outputs))
-        elif isinstance(var.op, MinibatchRandomVariable):
-            replacements.append((var.outputs[0], var.inputs[0]))
+    for node in fgraph.apply_nodes:
+        if isinstance(node.op, MinibatchOp):
+            replacements.extend(zip(node.outputs[:-1], node.inputs[:-1]))
+        elif isinstance(node.op, MinibatchRandomVariable):
+            replacements.append((node.outputs[0], node.inputs[0]))

-    toposort_replace(fgraph, replacements)
+    toposort_replace(fgraph, replacements, rebuild=True)
     return model_from_fgraph(fgraph, mutate_fgraph=True)
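
For orientation, a hypothetical end-to-end use of the two transforms touched above. The call signatures are inferred from the docstring in this diff (model, batch_size, optional minibatch_vars), and the model, data, and names are illustrative rather than taken from the commit:

import numpy as np
import pymc as pm

from pymc.model.transform.minibatch import minibatch_model, remove_minibatch

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
y = rng.normal(size=(1000,))

with pm.Model() as model:
    x_data = pm.Data("x_data", X)
    y_data = pm.Data("y_data", y)
    beta = pm.Normal("beta", shape=3)
    pm.Normal("y_obs", mu=x_data @ beta, sigma=1.0, observed=y_data)

# Replace the non-scalar data variables with Minibatch views and wrap the
# dependent observed RV in a MinibatchRandomVariable with the matching total_size.
mb_model = minibatch_model(model, batch_size=100)

# The inverse transform strips the Minibatch/MinibatchRandomVariable nodes again.
full_model = remove_minibatch(mb_model)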

pymc/pytensorf.py

Lines changed: 110 additions & 1 deletion
@@ -13,7 +13,9 @@
 # limitations under the License.
 import warnings

+from collections import deque
 from collections.abc import Iterable, Sequence
+from itertools import chain
 from typing import cast

 import numpy as np
@@ -1032,20 +1034,127 @@ def as_symbolic_string(x, **kwargs):
     return StringConstant(stringtype, x)


+def _replace_rebuild(
+    fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], **kwargs
+) -> FunctionGraph:
+    """Replace variables and rebuild dependent graph if needed.
+
+    Rebuilding allows for replacements that change the semantics of the graph
+    (different types), which may not be possible for all Ops.
+    """
+    fg_clients = fgraph.clients
+    fg_variables = fgraph.variables
+
+    def get_client_nodes(vars) -> set[Apply]:
+        # Start with the immediate clients of vars
+        nodes = set()
+        d = list(chain.from_iterable(fg_clients[var] for var in vars if var in fg_variables))
+        while d:
+            node, _ = d.pop()
+            if node in nodes or isinstance(node.op, Output):
+                continue
+            nodes.add(node)
+            # Keep walking to the successor clients
+            d.extend(chain.from_iterable(fg_clients[out] for out in node.outputs))
+        return nodes
+
+    repl_dict = dict(replacements)
+    root_nodes = {var.owner for var in repl_dict.keys()}
+
+    # Build sorted queue with all nodes that depend on replaced variables
+    topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
+    d = deque(
+        sorted(
+            get_client_nodes(repl_dict.keys()),
+            key=lambda node: topo_order[node],
+        )
+    )
+    while d:
+        node = d.popleft()
+        if node in root_nodes:
+            continue
+
+        new_inputs = [repl_dict.get(i, i) for i in node.inputs]
+        if new_inputs == node.inputs:
+            continue
+
+        # We need to remake the node if:
+        # 1. The output type depends on an input value
+        # 2. Any of the input types changed
+        if getattr(node.op, "_output_type_depends_on_input_value", False):
+            remake_node = True
+        else:
+            remake_node = any(
+                inp.type != new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
+            )
+
+        if remake_node:
+            new_node = node.clone_with_new_inputs(new_inputs, strict=False)
+            fgraph.import_node(new_node, import_missing=True)
+
+            # We are not always allowed to call `fgraph.replace_all` because the output types may be incompatible
+            # We will keep the changes in repl_dict until we can replace a node without remaking it,
+            # or we arrive at the end of the graph, in which case we need to replace the FunctionGraph output
+            for out, new_out in zip(node.outputs, new_node.outputs):
+                new_out.name = out.name
+                repl_dict[out] = new_out
+        else:
+            fgraph.replace_all(tuple(zip(node.inputs, new_inputs)), import_missing=True)

+    # If the FunctionGraph outputs themselves were rebuilt we need to handle them
+    for i, (new_output, old_output) in enumerate(
+        zip(
+            (repl_dict.get(out, out) for out in fgraph.outputs),
+            fgraph.outputs,
+        )
+    ):
+        if new_output is old_output:
+            continue
+        fgraph.outputs[i] = new_output
+        fgraph.import_var(new_output, import_missing=True)
+        fgraph.clients[new_output] = [
+            # Output variables have a special Output Op client
+            # We need to transfer it to the new output.
+            # Any other uses of this output variable will already have been substituted in the loop above,
+            # or are part of other outputs we will substitute next
+            (cl.op.make_node(new_output), idx) if isinstance(cl.op, Output) else (cl, idx)
+            for cl, idx in fgraph.clients[old_output]
+        ]
+    return fgraph
+
+
 def toposort_replace(
     fgraph: FunctionGraph,
     replacements: Sequence[tuple[Variable, Variable]],
     reverse: bool = False,
+    rebuild: bool = False,
 ) -> None:
     """Replace multiple variables in place in topological order."""
+    if rebuild and reverse:
+        raise NotImplementedError("reverse rebuild not yet supported")
+
     fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
     fgraph_toposort[None] = -1  # Variables without owner are not in the toposort
     sorted_replacements = sorted(
         replacements,
         key=lambda pair: fgraph_toposort[pair[0].owner],
         reverse=reverse,
     )
-    fgraph.replace_all(sorted_replacements, import_missing=True)
+
+    if rebuild:
+        if len(replacements) > 1:
+            # In this case we need to modify the replacements recursively with each other
+            sorted_replacements = [list(pairs) for pairs in sorted_replacements]
+            for i in range(1, len(replacements)):
+                temp_fgraph = FunctionGraph(
+                    outputs=[repl for _, repl in sorted_replacements[i:]],
+                    clone=False,
+                )
+                _replace_rebuild(temp_fgraph, replacements=sorted_replacements[:i])
+                sorted_replacements[i][1] = temp_fgraph.outputs[0]
+        _replace_rebuild(fgraph, sorted_replacements)
+    else:
+        fgraph.replace_all(sorted_replacements, import_missing=True)


 def normalize_rng_param(rng: None | Variable) -> Variable:
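
A sketch of what the new rebuild=True path is meant to allow, built only from the code added above (illustrative and "expected" rather than verified, since the commit is marked WIP): a replacement whose type differs from the original, here a vector with a different static shape, which plain fgraph.replace_all would generally reject; dependent nodes are instead re-created with clone_with_new_inputs so their output types follow.

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph

from pymc.pytensorf import toposort_replace

x = pt.vector("x", shape=(10,))
y = pt.exp(x)
fgraph = FunctionGraph(inputs=[x], outputs=[y], clone=False)

# Replacement with a different static shape: the exp node must be rebuilt,
# not just rewired, for the output type to change accordingly.
x_small = pt.vector("x_small", shape=(2,))
toposort_replace(fgraph, [(x, x_small)], rebuild=True)

print(fgraph.outputs[0].type.shape)  # expected: (2,)

When several replacements are passed together, the rebuild branch first rewrites each later replacement against the earlier ones through a temporary FunctionGraph, so replacement targets that reference already-replaced variables stay consistent.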

pymc/variational/minibatch_rv.py

Lines changed: 61 additions & 34 deletions
@@ -19,6 +19,7 @@
 from pytensor import Variable, config
 from pytensor.graph import Apply, Op
 from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable
+from pytensor.tensor.type_other import NoneTypeT

 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import logp
@@ -33,7 +34,9 @@ class MinibatchRandomVariable(MeasurableOp, Op):
     def make_node(self, rv, *total_size):
         rv = as_tensor_variable(rv)
         total_size = [
-            as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst
+            t
+            if isinstance(t, Variable)
+            else (NoneConst if t is None else as_tensor_variable(t, dtype="int64", ndim=0))
             for t in total_size
         ]
         assert len(total_size) == rv.ndim
@@ -55,45 +58,67 @@ def perform(self, node, inputs, output_storage):

 def create_minibatch_rv(
     rv: TensorVariable,
-    total_size: int | None | Sequence[int | EllipsisType | None],
+    total_size: int | TensorVariable | Sequence[int | TensorVariable | EllipsisType | None],
 ) -> TensorVariable:
     """Create variable whose logp is rescaled by total_size."""
+    rv_ndim_supp = rv.owner.op.ndim_supp
+
     if isinstance(total_size, int):
-        if rv.ndim <= 1:
-            total_size = [total_size]
+        total_size = (total_size, *([None] * rv_ndim_supp))
+    elif isinstance(total_size, TensorVariable):
+        if total_size.type.ndim == 0:
+            total_size = (total_size, *([None] * rv_ndim_supp))
+        elif total_size.type.ndim == 1:
+            total_size = tuple(total_size)
         else:
-            missing_ndims = rv.ndim - 1
-            total_size = [total_size] + [None] * missing_ndims
-    elif isinstance(total_size, list | tuple):
-        total_size = list(total_size)
-        if Ellipsis in total_size:
-            # Replace Ellipsis by None
-            if total_size.count(Ellipsis) > 1:
-                raise ValueError("Only one Ellipsis can be present in total_size")
-            sep = total_size.index(Ellipsis)
-            begin = total_size[:sep]
-            end = total_size[sep + 1 :]
-            missing_ndims = max((rv.ndim - len(begin) - len(end), 0))
-            total_size = begin + [None] * missing_ndims + end
-        if len(total_size) > rv.ndim:
-            raise ValueError(f"Length of total_size {total_size} is langer than RV ndim {rv.ndim}")
-    else:
-        raise TypeError(f"Invalid type for total_size: {total_size}")
-
-    return cast(TensorVariable, minibatch_rv(rv, *total_size))
-
-
-def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable:
+            raise ValueError(
+                f"Total size must be a 0d or 1d vector, got {total_size} with {total_size.type.ndim} dimensions"
+            )
+
+    if not isinstance(total_size, list | tuple):
+        raise ValueError(f"Invalid type for total_size {total_size}: {type(total_size)}")
+
+    if Ellipsis in total_size:
+        # Replace Ellipsis by None
+        if total_size.count(Ellipsis) > 1:
+            raise ValueError("Only one Ellipsis can be present in total_size")
+        sep = total_size.index(Ellipsis)
+        begin = total_size[:sep]
+        end = total_size[sep + 1 :]
+        missing_ndims = max((rv_ndim_supp - len(begin) - len(end), 0))
+        total_size = (*begin, *([None] * missing_ndims), *end)
+
+    if (len(total_size) - rv_ndim_supp) not in (0, 1):
+        raise ValueError(
+            f"Length of total_size {total_size} not compatible with ndim_supp of RV {rv}, "
+            f"got {len(total_size)} but must be {rv_ndim_supp} or {rv_ndim_supp - 1}"
+        )
+
+    out = minibatch_rv(rv, *total_size)
+    assert isinstance(out.owner.op, MinibatchRandomVariable)
+    return cast(TensorVariable, out)
+
+
+def get_scaling(
+    total_size: Sequence[TensorVariable], shape: TensorVariable | Sequence[TensorVariable]
+) -> TensorVariable:
     """Get scaling constant for logp."""
     # mypy doesn't understand we can convert a shape TensorVariable into a tuple
-    shape = tuple(shape)  # type: ignore[assignment]
+    shape = tuple(shape)
+
+    if len(total_size) == (len(shape) - 1):
+        # This happens when RV has no batch dimensions
+        # In that case the total_size corresponds to a dummy shape of 1
+        total_size = (1, *total_size)
+
+    assert len(shape) == len(total_size)

-    # Scalar RV
-    if len(shape) == 0:  # type: ignore[arg-type]
-        coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0
-    else:
-        coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)]
-        coef = pt.prod(coefs)
+    coefs = [
+        size / dim_length
+        for size, dim_length in zip(total_size, shape)
+        if not isinstance(size.type, NoneTypeT)
+    ]
+    coef = pt.prod(coefs) if len(coefs) > 1 else coefs[0]

     return pt.cast(coef, dtype=config.floatX)

@@ -102,4 +127,6 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor
 def minibatch_rv_logprob(op, values, *inputs, **kwargs):
     [value] = values
     rv, *total_size = inputs
-    return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape)
+    raw_logp = logp(rv, value, **kwargs)
+    scaled_logp = raw_logp * get_scaling(total_size, raw_logp.shape)
+    return scaled_logp
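
Two quick re-statements of the behaviour added in this file (illustrative, plain Python rather than the PyTensor graph; rv_ndim_supp stands for rv.owner.op.ndim_supp as used above).

First, the Ellipsis expansion in create_minibatch_rv, mirrored mechanically:

def expand_ellipsis(total_size, rv_ndim_supp):
    # The single Ellipsis is padded out with None entries until the sequence
    # has at least rv_ndim_supp entries, as in the + lines above.
    total_size = list(total_size)
    if total_size.count(Ellipsis) > 1:
        raise ValueError("Only one Ellipsis can be present in total_size")
    sep = total_size.index(Ellipsis)
    begin, end = total_size[:sep], total_size[sep + 1 :]
    missing = max(rv_ndim_supp - len(begin) - len(end), 0)
    return (*begin, *([None] * missing), *end)

assert expand_ellipsis((1000, ...), rv_ndim_supp=2) == (1000, None)
assert expand_ellipsis((1000, ..., 50), rv_ndim_supp=3) == (1000, None, 50)

Second, the rescaling rule in get_scaling with concrete numbers: None entries contribute no factor, each concrete entry contributes total_size[i] / shape[i], and a total_size one entry shorter than the shape gets a dummy leading 1:

def scaling_coef(total_size, shape):
    # Plain-number version of get_scaling, for intuition only.
    if len(total_size) == len(shape) - 1:
        total_size = (1, *total_size)  # dummy leading dimension, as in the diff
    coefs = [t / s for t, s in zip(total_size, shape) if t is not None]
    out = 1.0
    for c in coefs:
        out *= c
    return out

# A 100-row minibatch standing in for 1000 rows scales the logp by 10.
assert scaling_coef((1000, None), (100, 3)) == 10.0
assert scaling_coef((None, 500), (3, 50)) == 10.0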
