2727
2828from pymc .model .core import modelcontext
2929from pymc .pytensorf import _cheap_eval_mode
30- from pymc .util import VarName , get_default_varnames , get_var_name
30+ from pymc .util import get_default_varnames , get_var_name
3131
3232__all__ = (
3333 "ModelGraph" ,
@@ -173,7 +173,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
173173 }
174174
175175
176- def get_node_type (var_name : VarName , model ) -> NodeType :
176+ def get_node_type (var_name : str , model ) -> NodeType :
177177 """Return the node type of the variable in the model."""
178178 v = model [var_name ]
179179
@@ -242,7 +242,7 @@ def __init__(self, model):
242242 self ._all_vars = {model [var_name ] for var_name in self ._all_var_names }
243243 self .var_list = self .model .named_vars .values ()
244244
245- def get_parent_names (self , var : Variable ) -> set [VarName ]:
245+ def get_parent_names (self , var : Variable ) -> set [str ]:
246246 if var .owner is None :
247247 return set ()
248248
@@ -261,12 +261,12 @@ def _expand(x):
261261 return x .owner .inputs
262262
263263 return {
264- cast (VarName , ancestor .name ) # type: ignore[union-attr]
264+ cast (str , ancestor .name ) # type: ignore[union-attr]
265265 for ancestor in walk (nodes = var .owner .inputs , expand = _expand )
266266 if ancestor in named_vars
267267 }
268268
269- def vars_to_plot (self , var_names : Iterable [VarName ] | None = None ) -> list [VarName ]:
269+ def vars_to_plot (self , var_names : Iterable [str ] | None = None ) -> list [str ]:
270270 if var_names is None :
271271 return self ._all_var_names
272272
@@ -297,12 +297,12 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
297297 return [get_var_name (var ) for var in selected_ancestors ]
298298
299299 def make_compute_graph (
300- self , var_names : Iterable [VarName ] | None = None
301- ) -> dict [VarName , set [VarName ]]:
300+ self , var_names : Iterable [str ] | None = None
301+ ) -> dict [str , set [str ]]:
302302 """Get map of var_name -> set(input var names) for the model."""
303303 model = self .model
304304 named_vars = self ._all_vars
305- input_map : dict [VarName , set [VarName ]] = defaultdict (set )
305+ input_map : dict [str , set [str ]] = defaultdict (set )
306306
307307 var_names_to_plot = self .vars_to_plot (var_names )
308308 for var_name in var_names_to_plot :
@@ -319,15 +319,15 @@ def make_compute_graph(
319319 for ancestor in ancestors ([obs_var ]):
320320 if ancestor not in named_vars :
321321 continue
322- obs_name = cast (VarName , ancestor .name )
322+ obs_name = cast (str , ancestor .name )
323323 input_map [var_name ].discard (obs_name )
324324 input_map [obs_name ].add (var_name )
325325
326326 return input_map
327327
328328 def get_plates (
329329 self ,
330- var_names : Iterable [VarName ] | None = None ,
330+ var_names : Iterable [str ] | None = None ,
331331 ) -> list [Plate ]:
332332 """Rough but surprisingly accurate plate detection.
333333
@@ -337,7 +337,7 @@ def get_plates(
337337 Returns
338338 -------
339339 dict
340- Maps plate labels to the set of ``VarName ``s inside the plate.
340+ Maps plate labels to the set of ``str ``s inside the plate.
341341 """
342342 plates = defaultdict (set )
343343
@@ -389,8 +389,8 @@ def get_plates(
389389
390390 def edges (
391391 self ,
392- var_names : Iterable [VarName ] | None = None ,
393- ) -> list [tuple [VarName , VarName ]]:
392+ var_names : Iterable [str ] | None = None ,
393+ ) -> list [tuple [str , str ]]:
394394 """Get edges between the variables in the model.
395395
396396 Parameters
@@ -405,7 +405,7 @@ def edges(
405405
406406 """
407407 return [
408- (VarName (child .replace (":" , "&" )), VarName (parent .replace (":" , "&" )))
408+ (str (child .replace (":" , "&" )), str (parent .replace (":" , "&" )))
409409 for child , parents in self .make_compute_graph (var_names = var_names ).items ()
410410 for parent in parents
411411 ]
@@ -422,7 +422,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
422422def make_graph (
423423 name : str ,
424424 plates : list [Plate ],
425- edges : list [tuple [VarName , VarName ]],
425+ edges : list [tuple [str , str ]],
426426 formatting : str = "plain" ,
427427 save = None ,
428428 figsize = None ,
@@ -496,7 +496,7 @@ def make_graph(
496496def make_networkx (
497497 name : str ,
498498 plates : list [Plate ],
499- edges : list [tuple [VarName , VarName ]],
499+ edges : list [tuple [str , str ]],
500500 formatting : str = "plain" ,
501501 node_formatters : NodeTypeFormatterMapping | None = None ,
502502 create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
@@ -566,7 +566,7 @@ def make_networkx(
566566def model_to_networkx (
567567 model = None ,
568568 * ,
569- var_names : Iterable [VarName ] | None = None ,
569+ var_names : Iterable [str ] | None = None ,
570570 formatting : str = "plain" ,
571571 node_formatters : NodeTypeFormatterMapping | None = None ,
572572 include_dim_lengths : bool = True ,
@@ -660,7 +660,7 @@ def model_to_networkx(
660660def model_to_graphviz (
661661 model = None ,
662662 * ,
663- var_names : Iterable [VarName ] | None = None ,
663+ var_names : Iterable [str ] | None = None ,
664664 formatting : str = "plain" ,
665665 save : str | None = None ,
666666 figsize : tuple [int , int ] | None = None ,
0 commit comments