Skip to content

Commit d487cc9

Browse files
Remove VarName from codebase
- Remove VarName NewType definition from util.py - Replace all VarName type hints with str - Simplify get_var_name function to use var.name directly - Update imports in model_graph.py and model/core.py - Fix all type annotations and function signatures Resolves #7843
1 parent 3a0186e commit d487cc9

File tree

3 files changed

+24
-25
lines changed

3 files changed

+24
-25
lines changed

pymc/model/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
)
6868
from pymc.util import (
6969
UNSET,
70-
VarName,
7170
WithMemoization,
7271
_UnsetType,
7372
get_transformed_name,
@@ -1968,7 +1967,7 @@ def debug_parameters(rv):
19681967
def to_graphviz(
19691968
self,
19701969
*,
1971-
var_names: Iterable[VarName] | None = None,
1970+
var_names: Iterable[str] | None = None,
19721971
formatting: str = "plain",
19731972
save: str | None = None,
19741973
figsize: tuple[int, int] | None = None,
@@ -2172,7 +2171,7 @@ def compile_fn(
21722171
)
21732172

21742173

2175-
def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]:
2174+
def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]:
21762175
"""Build a point.
21772176
21782177
Uses same args as dict() does.

pymc/model_graph.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from pymc.model.core import modelcontext
2929
from pymc.pytensorf import _cheap_eval_mode
30-
from pymc.util import VarName, get_default_varnames, get_var_name
30+
from pymc.util import get_default_varnames, get_var_name
3131

3232
__all__ = (
3333
"ModelGraph",
@@ -173,7 +173,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
173173
}
174174

175175

176-
def get_node_type(var_name: VarName, model) -> NodeType:
176+
def get_node_type(var_name: str, model) -> NodeType:
177177
"""Return the node type of the variable in the model."""
178178
v = model[var_name]
179179

@@ -242,7 +242,7 @@ def __init__(self, model):
242242
self._all_vars = {model[var_name] for var_name in self._all_var_names}
243243
self.var_list = self.model.named_vars.values()
244244

245-
def get_parent_names(self, var: Variable) -> set[VarName]:
245+
def get_parent_names(self, var: Variable) -> set[str]:
246246
if var.owner is None:
247247
return set()
248248

@@ -261,12 +261,12 @@ def _expand(x):
261261
return x.owner.inputs
262262

263263
return {
264-
cast(VarName, ancestor.name) # type: ignore[union-attr]
264+
cast(str, ancestor.name) # type: ignore[union-attr]
265265
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
266266
if ancestor in named_vars
267267
}
268268

269-
def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
269+
def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]:
270270
if var_names is None:
271271
return self._all_var_names
272272

@@ -297,12 +297,12 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
297297
return [get_var_name(var) for var in selected_ancestors]
298298

299299
def make_compute_graph(
300-
self, var_names: Iterable[VarName] | None = None
301-
) -> dict[VarName, set[VarName]]:
300+
self, var_names: Iterable[str] | None = None
301+
) -> dict[str, set[str]]:
302302
"""Get map of var_name -> set(input var names) for the model."""
303303
model = self.model
304304
named_vars = self._all_vars
305-
input_map: dict[VarName, set[VarName]] = defaultdict(set)
305+
input_map: dict[str, set[str]] = defaultdict(set)
306306

307307
var_names_to_plot = self.vars_to_plot(var_names)
308308
for var_name in var_names_to_plot:
@@ -319,15 +319,15 @@ def make_compute_graph(
319319
for ancestor in ancestors([obs_var]):
320320
if ancestor not in named_vars:
321321
continue
322-
obs_name = cast(VarName, ancestor.name)
322+
obs_name = cast(str, ancestor.name)
323323
input_map[var_name].discard(obs_name)
324324
input_map[obs_name].add(var_name)
325325

326326
return input_map
327327

328328
def get_plates(
329329
self,
330-
var_names: Iterable[VarName] | None = None,
330+
var_names: Iterable[str] | None = None,
331331
) -> list[Plate]:
332332
"""Rough but surprisingly accurate plate detection.
333333
@@ -337,7 +337,7 @@ def get_plates(
337337
Returns
338338
-------
339339
dict
340-
Maps plate labels to the set of ``VarName``s inside the plate.
340+
Maps plate labels to the set of ``str``s inside the plate.
341341
"""
342342
plates = defaultdict(set)
343343

@@ -389,8 +389,8 @@ def get_plates(
389389

390390
def edges(
391391
self,
392-
var_names: Iterable[VarName] | None = None,
393-
) -> list[tuple[VarName, VarName]]:
392+
var_names: Iterable[str] | None = None,
393+
) -> list[tuple[str, str]]:
394394
"""Get edges between the variables in the model.
395395
396396
Parameters
@@ -405,7 +405,7 @@ def edges(
405405
406406
"""
407407
return [
408-
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
408+
(str(child.replace(":", "&")), str(parent.replace(":", "&")))
409409
for child, parents in self.make_compute_graph(var_names=var_names).items()
410410
for parent in parents
411411
]
@@ -422,7 +422,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
422422
def make_graph(
423423
name: str,
424424
plates: list[Plate],
425-
edges: list[tuple[VarName, VarName]],
425+
edges: list[tuple[str, str]],
426426
formatting: str = "plain",
427427
save=None,
428428
figsize=None,
@@ -496,7 +496,7 @@ def make_graph(
496496
def make_networkx(
497497
name: str,
498498
plates: list[Plate],
499-
edges: list[tuple[VarName, VarName]],
499+
edges: list[tuple[str, str]],
500500
formatting: str = "plain",
501501
node_formatters: NodeTypeFormatterMapping | None = None,
502502
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
@@ -566,7 +566,7 @@ def make_networkx(
566566
def model_to_networkx(
567567
model=None,
568568
*,
569-
var_names: Iterable[VarName] | None = None,
569+
var_names: Iterable[str] | None = None,
570570
formatting: str = "plain",
571571
node_formatters: NodeTypeFormatterMapping | None = None,
572572
include_dim_lengths: bool = True,
@@ -660,7 +660,7 @@ def model_to_networkx(
660660
def model_to_graphviz(
661661
model=None,
662662
*,
663-
var_names: Iterable[VarName] | None = None,
663+
var_names: Iterable[str] | None = None,
664664
formatting: str = "plain",
665665
save: str | None = None,
666666
figsize: tuple[int, int] | None = None,

pymc/util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from collections import namedtuple
1919
from collections.abc import Sequence
2020
from copy import deepcopy
21-
from typing import NewType, cast
21+
from typing import cast
2222

2323
import arviz
2424
import cloudpickle
@@ -31,7 +31,7 @@
3131

3232
from pymc.exceptions import BlockModelAccessError
3333

34-
VarName = NewType("VarName", str)
34+
3535

3636

3737
class _UnsetType:
@@ -214,9 +214,9 @@ def get_default_varnames(var_iterator, include_transformed):
214214
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]
215215

216216

217-
def get_var_name(var) -> VarName:
217+
def get_var_name(var) -> str:
218218
"""Get an appropriate, plain variable name for a variable."""
219-
return VarName(str(getattr(var, "name", var)))
219+
return var.name if var.name is not None else str(var)
220220

221221

222222
def get_transformed(z):

0 commit comments

Comments
 (0)