diff --git a/mesa/examples/advanced/epstein_civil_violence/app.py b/mesa/examples/advanced/epstein_civil_violence/app.py index 0eba101af98..376a4eec643 100644 --- a/mesa/examples/advanced/epstein_civil_violence/app.py +++ b/mesa/examples/advanced/epstein_civil_violence/app.py @@ -64,8 +64,11 @@ def post_process(ax): ) epstein_model = EpsteinCivilViolence() -renderer = SpaceRenderer(epstein_model, backend="matplotlib") -renderer.draw_agents(citizen_cop_portrayal) +renderer = SpaceRenderer(epstein_model, backend="matplotlib").setup_agents( + citizen_cop_portrayal +) +# Specifically, avoid drawing the grid to hide the grid lines. +renderer.draw_agents() renderer.post_process = post_process page = SolaraViz( diff --git a/mesa/examples/advanced/pd_grid/app.py b/mesa/examples/advanced/pd_grid/app.py index 580b582af47..efc4f409760 100644 --- a/mesa/examples/advanced/pd_grid/app.py +++ b/mesa/examples/advanced/pd_grid/app.py @@ -45,7 +45,11 @@ def pd_agent_portrayal(agent): # Initialize model initial_model = PdGrid() # Create grid and agent visualization component using Altair -renderer = SpaceRenderer(initial_model, backend="altair").render(pd_agent_portrayal) +renderer = ( + SpaceRenderer(initial_model, backend="altair") + .setup_agents(pd_agent_portrayal) + .render() +) # Create visualization with all components page = SolaraViz( diff --git a/mesa/examples/advanced/sugarscape_g1mt/app.py b/mesa/examples/advanced/sugarscape_g1mt/app.py index 0fb767e1e4e..33eed8b3474 100644 --- a/mesa/examples/advanced/sugarscape_g1mt/app.py +++ b/mesa/examples/advanced/sugarscape_g1mt/app.py @@ -57,11 +57,16 @@ def post_process(chart): # Here, the renderer uses the Altair backend, while the plot components # use the Matplotlib backend. # Both can be mixed and matched to enhance the visuals of your model. -renderer = SpaceRenderer(model, backend="altair").render( - agent_portrayal=agent_portrayal, - propertylayer_portrayal=propertylayer_portrayal, - post_process=post_process, +renderer = ( + SpaceRenderer(model, backend="altair") + .setup_agents(agent_portrayal) + .setup_propertylayer(propertylayer_portrayal) ) +# Specifically, avoid drawing the grid to hide the grid lines. +renderer.draw_agents() +renderer.draw_propertylayer() + +renderer.post_process = post_process # Note: It is advised to switch the pages after pausing the model # on the Solara dashboard. diff --git a/mesa/examples/advanced/wolf_sheep/app.py b/mesa/examples/advanced/wolf_sheep/app.py index 8b33b0855cd..e17080e53e9 100644 --- a/mesa/examples/advanced/wolf_sheep/app.py +++ b/mesa/examples/advanced/wolf_sheep/app.py @@ -84,9 +84,9 @@ def post_process_lines(ax): renderer = SpaceRenderer( model, backend="matplotlib", -) -renderer.draw_agents(wolf_sheep_portrayal) +).setup_agents(wolf_sheep_portrayal) renderer.post_process = post_process_space +renderer.draw_agents() page = SolaraViz( model, diff --git a/mesa/examples/basic/boid_flockers/app.py b/mesa/examples/basic/boid_flockers/app.py index c008e1bce6e..392543f1b52 100644 --- a/mesa/examples/basic/boid_flockers/app.py +++ b/mesa/examples/basic/boid_flockers/app.py @@ -70,10 +70,14 @@ def boid_draw(agent): model = BoidFlockers() # Quickest way to visualize grid along with agents or property layers. -renderer = SpaceRenderer( - model, - backend="matplotlib", -).render(agent_portrayal=boid_draw) +renderer = ( + SpaceRenderer( + model, + backend="matplotlib", + ) + .setup_agents(boid_draw) + .render() +) page = SolaraViz( model, diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py index 464708536ca..e3eb0c5381f 100644 --- a/mesa/examples/basic/boltzmann_wealth_model/app.py +++ b/mesa/examples/basic/boltzmann_wealth_model/app.py @@ -60,10 +60,15 @@ def post_process(chart): # It builds the visualization in layers, first drawing the grid structure, # and then drawing the agents on top. It uses a specified backend # (like "altair" or "matplotlib") for creating the plots. -renderer = SpaceRenderer(model, backend="altair") -# Can customize the grid appearance. -renderer.draw_structure(grid_color="black", grid_dash=[6, 2], grid_opacity=0.3) -renderer.draw_agents(agent_portrayal=agent_portrayal, cmap="viridis", vmin=0, vmax=10) + +renderer = ( + SpaceRenderer(model, backend="altair") + .setup_structure( # To customize the grid appearance. + grid_color="black", grid_dash=[6, 2], grid_opacity=0.3 + ) + .setup_agents(agent_portrayal, cmap="viridis", vmin=0, vmax=10) +) +renderer.render() # The post_process function is used to modify the Altair chart after it has been created. # It can be used to add legends, colorbars, or other visual elements. diff --git a/mesa/examples/basic/conways_game_of_life/app.py b/mesa/examples/basic/conways_game_of_life/app.py index 618b72dbaf7..ac4f1134b0a 100644 --- a/mesa/examples/basic/conways_game_of_life/app.py +++ b/mesa/examples/basic/conways_game_of_life/app.py @@ -55,10 +55,10 @@ def post_process(ax): # Create initial model instance model1 = ConwaysGameOfLife() -renderer = SpaceRenderer(model1, backend="matplotlib") +renderer = SpaceRenderer(model1, backend="matplotlib").setup_agents(agent_portrayal) # In this case the renderer only draws the agents because we just want to observe # the state of the agents, not the structure of the grid. -renderer.draw_agents(agent_portrayal=agent_portrayal) +renderer.draw_agents() renderer.post_process = post_process # Create the SolaraViz page. This will automatically create a server and display the diff --git a/mesa/examples/basic/schelling/app.py b/mesa/examples/basic/schelling/app.py index 32395a66d5e..cebf2a64f1e 100644 --- a/mesa/examples/basic/schelling/app.py +++ b/mesa/examples/basic/schelling/app.py @@ -73,12 +73,11 @@ def agent_portrayal(agent): # Note: Models with images as markers are very performance intensive. model1 = Schelling() -renderer = SpaceRenderer(model1, backend="matplotlib") +renderer = SpaceRenderer(model1, backend="matplotlib").setup_agents(agent_portrayal) # Here we use renderer.render() to render the agents and grid in one go. # This function always renders the grid and then renders the agents or -# property layers on top of it if specified. It also supports passing the -# post_process function to fine-tune the plot after rendering in itself. -renderer.render(agent_portrayal=agent_portrayal) +# property layers on top of it if specified. +renderer.render() HappyPlot = make_plot_component({"happy": "tab:green"}) diff --git a/mesa/examples/basic/virus_on_network/app.py b/mesa/examples/basic/virus_on_network/app.py index b9de98eae8a..5d17398245d 100644 --- a/mesa/examples/basic/virus_on_network/app.py +++ b/mesa/examples/basic/virus_on_network/app.py @@ -109,12 +109,16 @@ def post_process_lineplot(chart): model1 = VirusOnNetwork() -renderer = SpaceRenderer(model1, backend="altair") -renderer.draw_structure( - node_kwargs={"color": "black", "filled": False, "strokeWidth": 5}, - edge_kwargs={"strokeDash": [6, 1]}, -) # Do this to draw the underlying network and customize it -renderer.draw_agents(agent_portrayal) +renderer = ( + SpaceRenderer(model1, backend="altair") + .setup_structure( # Do this to draw the underlying network and customize it + node_kwargs={"color": "black", "filled": False, "strokeWidth": 5}, + edge_kwargs={"strokeDash": [6, 1]}, + ) + .setup_agents(agent_portrayal) +) + +renderer.render() # Plot components can also be in altair and support post_process StatePlot = make_plot_component( diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 3bd33e5b432..87b33ea2a9f 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -269,19 +269,12 @@ def SpaceRendererComponent( for artist in itertools.chain.from_iterable(all_artists): artist.remove() - # Draw the space structure if specified if renderer.space_mesh: - renderer.draw_structure(**renderer.space_kwargs) - - # Draw agents if specified + renderer.draw_structure() if renderer.agent_mesh: - renderer.draw_agents( - agent_portrayal=renderer.agent_portrayal, **renderer.agent_kwargs - ) - - # Draw property layers if specified + renderer.draw_agents() if renderer.propertylayer_mesh: - renderer.draw_propertylayer(renderer.propertylayer_portrayal) + renderer.draw_propertylayer() # Update the fig every time frame if dependencies: @@ -306,15 +299,11 @@ def SpaceRendererComponent( propertylayer = renderer.propertylayer_mesh or None if renderer.space_mesh: - structure = renderer.draw_structure(**renderer.space_kwargs) + structure = renderer.draw_structure() if renderer.agent_mesh: - agents = renderer.draw_agents( - renderer.agent_portrayal, **renderer.agent_kwargs - ) + agents = renderer.draw_agents() if renderer.propertylayer_mesh: - propertylayer = renderer.draw_propertylayer( - renderer.propertylayer_portrayal - ) + propertylayer = renderer.draw_propertylayer() spatial_charts_list = [ chart for chart in [structure, propertylayer, agents] if chart diff --git a/mesa/visualization/space_drawers.py b/mesa/visualization/space_drawers.py index da5a4079533..e17a420f2b5 100644 --- a/mesa/visualization/space_drawers.py +++ b/mesa/visualization/space_drawers.py @@ -82,12 +82,12 @@ def __init__(self, space: OrthogonalGrid): self.viz_ymin = -0.5 self.viz_ymax = self.space.height - 0.5 - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the orthogonal grid using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Additional keyword arguments for styling. + **draw_space_kwargs: Additional keyword arguments for styling. Examples: figsize=(10, 10), color="blue", linewidth=2. @@ -96,8 +96,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs): The modified axes object """ fig_kwargs = { - "figsize": space_kwargs.pop("figsize", (8, 8)), - "dpi": space_kwargs.pop("dpi", 100), + "figsize": draw_space_kwargs.pop("figsize", (8, 8)), + "dpi": draw_space_kwargs.pop("dpi", 100), } if ax is None: @@ -110,7 +110,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): "linewidth": 1, "alpha": 1, } - line_kwargs.update(space_kwargs) + line_kwargs.update(draw_space_kwargs) ax.set_xlim(self.viz_xmin, self.viz_xmax) ax.set_ylim(self.viz_ymin, self.viz_ymax) @@ -123,13 +123,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the orthogonal grid using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Additional keyword arguments for styling the chart. + **draw_chart_kwargs: Additional keyword arguments for styling the chart. Examples: width=500, height=500, title="Grid". @@ -139,12 +139,12 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): """ # for axis and grid styling axis_kwargs = { - "xlabel": chart_kwargs.pop("xlabel", "X"), - "ylabel": chart_kwargs.pop("ylabel", "Y"), - "grid_color": chart_kwargs.pop("grid_color", "lightgray"), - "grid_dash": chart_kwargs.pop("grid_dash", [2, 2]), - "grid_width": chart_kwargs.pop("grid_width", 1), - "grid_opacity": chart_kwargs.pop("grid_opacity", 1), + "xlabel": draw_chart_kwargs.pop("xlabel", "X"), + "ylabel": draw_chart_kwargs.pop("ylabel", "Y"), + "grid_color": draw_chart_kwargs.pop("grid_color", "lightgray"), + "grid_dash": draw_chart_kwargs.pop("grid_dash", [2, 2]), + "grid_width": draw_chart_kwargs.pop("grid_width", 1), + "grid_opacity": draw_chart_kwargs.pop("grid_opacity", 1), } # for chart properties @@ -152,7 +152,7 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): "width": chart_width, "height": chart_height, } - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) chart = ( alt.Chart(pd.DataFrame([{}])) @@ -263,12 +263,12 @@ def _get_unique_edges(self): edges.add(edge) return edges - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the hexagonal grid using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Additional keyword arguments for styling. + **draw_space_kwargs: Additional keyword arguments for styling. Examples: figsize=(8, 8), color="red", alpha=0.5. @@ -277,8 +277,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs): The modified axes object """ fig_kwargs = { - "figsize": space_kwargs.pop("figsize", (8, 8)), - "dpi": space_kwargs.pop("dpi", 100), + "figsize": draw_space_kwargs.pop("figsize", (8, 8)), + "dpi": draw_space_kwargs.pop("dpi", 100), } if ax is None: @@ -290,7 +290,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): "linewidth": 1, "alpha": 0.8, } - line_kwargs.update(space_kwargs) + line_kwargs.update(draw_space_kwargs) ax.set_xlim(self.viz_xmin, self.viz_xmax) ax.set_ylim(self.viz_ymin, self.viz_ymax) @@ -300,13 +300,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): ax.add_collection(LineCollection(list(edges), **line_kwargs)) return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the hexagonal grid using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Additional keyword arguments for styling the chart. + **draw_chart_kwargs: Additional keyword arguments for styling the chart. Examples: * Line properties like color, strokeDash, strokeWidth, opacity. @@ -316,17 +316,17 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): Altair chart object representing the hexagonal grid. """ mark_kwargs = { - "color": chart_kwargs.pop("color", "black"), - "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]), - "strokeWidth": chart_kwargs.pop("strokeWidth", 1), - "opacity": chart_kwargs.pop("opacity", 0.8), + "color": draw_chart_kwargs.pop("color", "black"), + "strokeDash": draw_chart_kwargs.pop("strokeDash", [2, 2]), + "strokeWidth": draw_chart_kwargs.pop("strokeWidth", 1), + "opacity": draw_chart_kwargs.pop("opacity", 0.8), } chart_props = { "width": chart_width, "height": chart_height, } - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) edge_data = [] edges = self._get_unique_edges() @@ -400,12 +400,12 @@ def __init__( self.viz_ymin = ymin - height / 20 self.viz_ymax = ymax + height / 20 - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the network using matplotlib. Args: ax: Matplotlib axes object to draw on. - **space_kwargs: Dictionaries of keyword arguments for styling. + **draw_space_kwargs: Dictionaries of keyword arguments for styling. Can also handle zorder for both nodes and edges if passed. * ``node_kwargs``: A dict passed to nx.draw_networkx_nodes. * ``edge_kwargs``: A dict passed to nx.draw_networkx_edges. @@ -423,8 +423,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs): node_kwargs = {"alpha": 0.5} edge_kwargs = {"alpha": 0.5, "style": "--"} - node_kwargs.update(space_kwargs.get("node_kwargs", {})) - edge_kwargs.update(space_kwargs.get("edge_kwargs", {})) + node_kwargs.update(draw_space_kwargs.get("node_kwargs", {})) + edge_kwargs.update(draw_space_kwargs.get("edge_kwargs", {})) node_zorder = node_kwargs.pop("zorder", 1) edge_zorder = edge_kwargs.pop("zorder", 0) @@ -443,13 +443,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the network using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Dictionaries for styling the chart. + **draw_chart_kwargs: Dictionaries for styling the chart. * ``node_kwargs``: A dict of properties for the node's mark_point. * ``edge_kwargs``: A dict of properties for the edge's mark_rule. * Other kwargs (e.g., title, width) are passed to chart.properties(). @@ -474,14 +474,14 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): node_mark_kwargs = {"filled": True, "opacity": 0.5, "size": 500} edge_mark_kwargs = {"opacity": 0.5, "strokeDash": [5, 3]} - node_mark_kwargs.update(chart_kwargs.pop("node_kwargs", {})) - edge_mark_kwargs.update(chart_kwargs.pop("edge_kwargs", {})) + node_mark_kwargs.update(draw_chart_kwargs.pop("node_kwargs", {})) + edge_mark_kwargs.update(draw_chart_kwargs.pop("edge_kwargs", {})) - chart_kwargs = { + chart_props = { "width": chart_width, "height": chart_height, } - chart_kwargs.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) edge_plot = ( alt.Chart(edge_positions) @@ -510,8 +510,8 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): chart = edge_plot + node_plot - if chart_kwargs: - chart = chart.properties(**chart_kwargs) + if chart_props: + chart = chart.properties(**chart_props) return chart @@ -540,12 +540,12 @@ def __init__(self, space: ContinuousSpace): self.viz_ymin = self.space.y_min - y_padding self.viz_ymax = self.space.y_max + y_padding - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the continuous space using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Keyword arguments for styling the axis frame. + **draw_space_kwargs: Keyword arguments for styling the axis frame. Examples: linewidth=3, color="green" @@ -558,7 +558,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): border_style = "solid" if not self.space.torus else (0, (5, 10)) spine_kwargs = {"linewidth": 1.5, "color": "black", "linestyle": border_style} - spine_kwargs.update(space_kwargs) + spine_kwargs.update(draw_space_kwargs) for spine in ax.spines.values(): spine.set(**spine_kwargs) @@ -568,20 +568,20 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the continuous space using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Keyword arguments for styling the chart's view properties. + **draw_chart_kwargs: Keyword arguments for styling the chart's view properties. See Altair's documentation for `configure_view`. Returns: An Altair Chart object representing the space. """ chart_props = {"width": chart_width, "height": chart_height} - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) chart = ( alt.Chart(pd.DataFrame([{}])) @@ -712,12 +712,12 @@ def _get_clipped_segments(self): return final_segments, clip_box - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the Voronoi diagram using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Keyword arguments passed to matplotlib's LineCollection. + **draw_space_kwargs: Keyword arguments passed to matplotlib's LineCollection. Examples: lw=2, alpha=0.5, colors='red' @@ -736,7 +736,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): if final_segments: # Define default styles for the plot style_args = {"colors": "k", "linestyle": "dotted", "lw": 1} - style_args.update(space_kwargs) + style_args.update(draw_space_kwargs) # Create the LineCollection with the final styles lc = LineCollection(final_segments, **style_args) @@ -744,13 +744,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the Voronoi diagram using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Additional keyword arguments for styling the chart. + **draw_chart_kwargs: Additional keyword arguments for styling the chart. Examples: * Line properties like color, strokeDash, strokeWidth, opacity. @@ -771,14 +771,14 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): # Define default properties for the mark mark_kwargs = { - "color": chart_kwargs.pop("color", "black"), - "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]), - "strokeWidth": chart_kwargs.pop("strokeWidth", 1), - "opacity": chart_kwargs.pop("opacity", 0.8), + "color": draw_chart_kwargs.pop("color", "black"), + "strokeDash": draw_chart_kwargs.pop("strokeDash", [2, 2]), + "strokeWidth": draw_chart_kwargs.pop("strokeWidth", 1), + "opacity": draw_chart_kwargs.pop("opacity", 0.8), } chart_props = {"width": chart_width, "height": chart_height} - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) chart = ( alt.Chart(df) diff --git a/mesa/visualization/space_renderer.py b/mesa/visualization/space_renderer.py index a517a70b9da..c60bb9c2214 100644 --- a/mesa/visualization/space_renderer.py +++ b/mesa/visualization/space_renderer.py @@ -4,6 +4,8 @@ backends, supporting various space types and visualization components. """ +from __future__ import annotations + import contextlib import warnings from collections.abc import Callable @@ -63,10 +65,17 @@ def __init__( self.space = getattr(model, "grid", getattr(model, "space", None)) self.space_drawer = self._get_space_drawer() + self.space_mesh = None self.agent_mesh = None self.propertylayer_mesh = None + self.draw_agent_kwargs = {} + self.draw_space_kwargs = {} + + self.agent_portrayal = None + self.propertylayer_portrayal = None + self.post_process_func = None # Keep track of whether post-processing has been applied # to avoid multiple applications on the same axis. @@ -161,54 +170,122 @@ def _map_coordinates(self, arguments): return mapped_arguments + def setup_structure(self, **kwargs) -> SpaceRenderer: + """Setup the space structure without drawing. + + Args: + **kwargs: Additional keyword arguments for the setup function. + Checkout respective `SpaceDrawer` class on details how to pass **kwargs. + + Returns: + SpaceRenderer: The current instance for method chaining. + """ + self.draw_space_kwargs = kwargs + self.space_mesh = None + + return self + + def setup_agents(self, agent_portrayal: Callable, **kwargs) -> SpaceRenderer: + """Setup agents on the space without drawing. + + Args: + agent_portrayal (Callable): Function that takes an agent and returns AgentPortrayalStyle. + **kwargs: Additional keyword arguments for the setup function. + Checkout respective `SpaceDrawer` class on details how to pass **kwargs. + + Returns: + SpaceRenderer: The current instance for method chaining. + """ + self.agent_portrayal = agent_portrayal + self.draw_agent_kwargs = kwargs + self.agent_mesh = None + + return self + + def setup_propertylayer( + self, propertylayer_portrayal: Callable | dict + ) -> SpaceRenderer: + """Setup property layers on the space without drawing. + + Args: + propertylayer_portrayal (Callable | dict): Function that returns PropertyLayerStyle + or dict with portrayal parameters. + + Returns: + SpaceRenderer: The current instance for method chaining. + """ + self.propertylayer_portrayal = propertylayer_portrayal + self.propertylayer_mesh = None + + return self + def draw_structure(self, **kwargs): """Draw the space structure. Args: - **kwargs: Additional keyword arguments for the drawing function. - Checkout respective `SpaceDrawer` class on details how to pass **kwargs. + **kwargs: (Deprecated) Additional keyword arguments for drawing. + Use setup_structure() instead. Returns: The visual representation of the space structure. """ - # Store space_kwargs for internal use - self.space_kwargs = kwargs + if kwargs: + warnings.warn( + "Passing kwargs to draw_structure() is deprecated. " + "Use setup_structure(**kwargs) before calling draw_structure().", + PendingDeprecationWarning, + stacklevel=2, + ) + self.draw_space_kwargs.update(kwargs) - self.space_mesh = self.backend_renderer.draw_structure(**self.space_kwargs) + self.space_mesh = self.backend_renderer.draw_structure(**self.draw_space_kwargs) return self.space_mesh - def draw_agents(self, agent_portrayal: Callable, **kwargs): + def draw_agents(self, agent_portrayal=None, **kwargs): """Draw agents on the space. Args: - agent_portrayal (Callable): Function that takes an agent and returns AgentPortrayalStyle. - **kwargs: Additional keyword arguments for the drawing function. - Checkout respective `SpaceDrawer` class on details how to pass **kwargs. + agent_portrayal: (Deprecated) Function that takes an agent and returns AgentPortrayalStyle. + Use setup_agents() instead. + **kwargs: (Deprecated) Additional keyword arguments for drawing. Returns: The visual representation of the agents. """ - # Store data for internal use - self.agent_portrayal = agent_portrayal - self.agent_kwargs = kwargs + if agent_portrayal is not None: + warnings.warn( + "Passing agent_portrayal to draw_agents() is deprecated. " + "Use setup_agents(agent_portrayal, **kwargs) before calling draw_agents().", + PendingDeprecationWarning, + stacklevel=2, + ) + self.agent_portrayal = agent_portrayal + if kwargs: + warnings.warn( + "Passing kwargs to draw_agents() is deprecated. " + "Use setup_agents(**kwargs) before calling draw_agents().", + PendingDeprecationWarning, + stacklevel=2, + ) + self.draw_agent_kwargs.update(kwargs) # Prepare data for agent plotting arguments = self.backend_renderer.collect_agent_data( - self.space, agent_portrayal, default_size=self.space_drawer.s_default + self.space, self.agent_portrayal, default_size=self.space_drawer.s_default ) arguments = self._map_coordinates(arguments) self.agent_mesh = self.backend_renderer.draw_agents( - arguments, **self.agent_kwargs + arguments, **self.draw_agent_kwargs ) return self.agent_mesh - def draw_propertylayer(self, propertylayer_portrayal: Callable | dict): + def draw_propertylayer(self, propertylayer_portrayal=None): """Draw property layers on the space. Args: - propertylayer_portrayal (Callable | dict): Function that returns PropertyLayerStyle - or dict with portrayal parameters. + propertylayer_portrayal: (Deprecated) Function that takes a property layer and returns PropertyLayerStyle. + Use setup_propertylayer() instead. Returns: The visual representation of the property layers. @@ -216,6 +293,15 @@ def draw_propertylayer(self, propertylayer_portrayal: Callable | dict): Raises: Exception: If no property layers are found on the space. """ + if propertylayer_portrayal is not None: + warnings.warn( + "Passing propertylayer_portrayal to draw_propertylayer() is deprecated. " + "Use setup_propertylayer(propertylayer_portrayal) before calling draw_propertylayer().", + PendingDeprecationWarning, + stacklevel=2, + ) + self.propertylayer_portrayal = propertylayer_portrayal + # Import here to avoid circular imports from mesa.visualization.components import PropertyLayerStyle # noqa: PLC0415 @@ -267,10 +353,10 @@ def style_callable(layer_object): property_layers = self.space._mesa_property_layers # Convert portrayal to callable if needed - if isinstance(propertylayer_portrayal, dict): - self.propertylayer_portrayal = _dict_to_callable(propertylayer_portrayal) - else: - self.propertylayer_portrayal = propertylayer_portrayal + if isinstance(self.propertylayer_portrayal, dict): + self.propertylayer_portrayal = _dict_to_callable( + self.propertylayer_portrayal + ) number_of_propertylayers = sum( [1 for layer in property_layers if layer != "empty"] @@ -283,41 +369,34 @@ def style_callable(layer_object): ) return self.propertylayer_mesh - def render( - self, - agent_portrayal: Callable | None = None, - propertylayer_portrayal: Callable | dict | None = None, - post_process: Callable | None = None, - **kwargs, - ): + def render(self, agent_portrayal=None, propertylayer_portrayal=None, **kwargs): """Render the complete space with structure, agents, and property layers. - It is an all-in-one method that draws everything required therefore eliminates - the need of calling each method separately, but has a drawback, if want to pass - kwargs to customize the drawing, they have to be broken into - space_kwargs and agent_kwargs. - Args: - agent_portrayal (Callable | None): Function that returns AgentPortrayalStyle. - If None, agents won't be drawn. - propertylayer_portrayal (Callable | dict | None): Function that returns - PropertyLayerStyle or dict with portrayal parameters. If None, - property layers won't be drawn. - post_process (Callable | None): Function to apply post-processing to the canvas. - **kwargs: Additional keyword arguments for drawing functions. - * ``space_kwargs`` (dict): Arguments for ``draw_structure()``. - * ``agent_kwargs`` (dict): Arguments for ``draw_agents()``. + agent_portrayal: (Deprecated) Function for agent portrayal. Use setup_agents() instead. + propertylayer_portrayal: (Deprecated) Function for property layer portrayal. Use setup_propertylayer() instead. + **kwargs: (Deprecated) Additional keyword arguments. """ - space_kwargs = kwargs.pop("space_kwargs", {}) - agent_kwargs = kwargs.pop("agent_kwargs", {}) + if agent_portrayal is not None or propertylayer_portrayal is not None or kwargs: + warnings.warn( + "Passing parameters to render() is deprecated. " + "Use setup_structure(), setup_agents(), and setup_propertylayer() before calling render().", + PendingDeprecationWarning, + stacklevel=2, + ) + if agent_portrayal is not None: + self.agent_portrayal = agent_portrayal + if propertylayer_portrayal is not None: + self.propertylayer_portrayal = propertylayer_portrayal + self.draw_space_kwargs.update(kwargs) + if self.space_mesh is None: - self.draw_structure(**space_kwargs) - if self.agent_mesh is None and agent_portrayal is not None: - self.draw_agents(agent_portrayal, **agent_kwargs) - if self.propertylayer_mesh is None and propertylayer_portrayal is not None: - self.draw_propertylayer(propertylayer_portrayal) + self.draw_structure() + if self.agent_mesh is None and self.agent_portrayal is not None: + self.draw_agents() + if self.propertylayer_mesh is None and self.propertylayer_portrayal is not None: + self.draw_propertylayer() - self.post_process_func = post_process return self @property @@ -339,13 +418,11 @@ def canvas(self): prop_base, prop_cbar = self.propertylayer_mesh or (None, None) if self.space_mesh: - structure = self.draw_structure(**self.space_kwargs) + structure = self.draw_structure() if self.agent_mesh: - agents = self.draw_agents(self.agent_portrayal, **self.agent_kwargs) + agents = self.draw_agents() if self.propertylayer_mesh: - prop_base, prop_cbar = self.draw_propertylayer( - self.propertylayer_portrayal - ) + prop_base, prop_cbar = self.draw_propertylayer() spatial_charts_list = [ chart for chart in [structure, prop_base, agents] if chart diff --git a/tests/test_solara_viz_updated.py b/tests/test_solara_viz_updated.py index 03db138fa2d..07aad6af2de 100644 --- a/tests/test_solara_viz_updated.py +++ b/tests/test_solara_viz_updated.py @@ -147,8 +147,12 @@ def agent_portrayal(agent): propertylayer_portrayal = None - renderer = SpaceRenderer(model, backend="matplotlib") - renderer.render(agent_portrayal, propertylayer_portrayal) + renderer = ( + SpaceRenderer(model, backend="matplotlib") + .setup_agents(agent_portrayal) + .setup_propertylayer(propertylayer_portrayal) + .render() + ) # component must be rendered for code to run solara.render( @@ -165,7 +169,7 @@ def agent_portrayal(agent): ) mock_draw_space.assert_called_with(renderer) - mock_draw_agents.assert_called_with(renderer, agent_portrayal) + mock_draw_agents.assert_called_with(renderer) # should not call this method if portrayal is None mock_draw_properties.assert_not_called() @@ -192,8 +196,12 @@ def agent_portrayal(agent): } solara.render(SolaraViz(model, renderer, components=[])) - renderer = SpaceRenderer(model, backend="altair") - renderer.render(agent_portrayal, propertylayer_portrayal) + renderer = ( + SpaceRenderer(model, backend="altair") + .setup_agents(agent_portrayal) + .setup_propertylayer(propertylayer_portrayal) + .render() + ) assert renderer.backend == "altair" assert isinstance( @@ -201,8 +209,8 @@ def agent_portrayal(agent): ) mock_draw_space.assert_called_with(renderer) - mock_draw_agents.assert_called_with(renderer, agent_portrayal) - mock_draw_properties.assert_called_with(renderer, propertylayer_portrayal) + mock_draw_agents.assert_called_with(renderer) + mock_draw_properties.assert_called_with(renderer) mock_draw_space.reset_mock() mock_draw_agents.reset_mock() @@ -210,7 +218,7 @@ def agent_portrayal(agent): solara.render(SolaraViz(model)) - # noting is drawn if renderer is not passed + # nothing is drawn if renderer is not passed assert mock_draw_space.call_count == 0 assert mock_draw_agents.call_count == 0 assert mock_draw_properties.call_count == 0 diff --git a/tests/test_space_renderer.py b/tests/test_space_renderer.py index d74c6468bd8..5072b35143b 100644 --- a/tests/test_space_renderer.py +++ b/tests/test_space_renderer.py @@ -117,10 +117,9 @@ def test_render_calls(): sr.draw_agents = MagicMock() sr.draw_propertylayer = MagicMock() - sr.render( - agent_portrayal=lambda _: {}, - propertylayer_portrayal=lambda _: PropertyLayerStyle(color="red"), - ) + sr.setup_agents(agent_portrayal=lambda _: {}).setup_propertylayer( + propertylayer_portrayal=lambda _: PropertyLayerStyle(color="red") + ).render() sr.draw_structure.assert_called_once() sr.draw_agents.assert_called_once() @@ -139,7 +138,9 @@ def test_no_property_layers(): Exception, match=re.escape("No property layers were found on the space.") ), ): - sr.draw_propertylayer(lambda _: PropertyLayerStyle(color="red")) + sr.setup_propertylayer( + lambda _: PropertyLayerStyle(color="red") + ).draw_propertylayer() def test_post_process():