112 changes: 89 additions & 23 deletions mesa_frames/concrete/datacollector.py
@@ -64,14 +64,15 @@ def step(self):
 from collections.abc import Callable
 from mesa_frames import Model
 from psycopg2.extensions import connection
+import logging


 class DataCollector(AbstractDataCollector):
     def __init__(
         self,
         model: Model,
         model_reporters: dict[str, Callable] | None = None,
-        agent_reporters: dict[str, str | Callable] | None = None,
+        agent_reporters: dict[str, str | Callable] | None = None,  # <-- ALLOWS CALLABLE
         trigger: Callable[[Any], bool] | None = None,
         reset_memory: bool = True,
         storage: Literal[
@@ -91,7 +92,10 @@ def __init__(
         model_reporters : dict[str, Callable] | None
             Functions to collect data at the model level.
         agent_reporters : dict[str, str | Callable] | None
-            Attributes or functions to collect data at the agent level.
+            A dictionary mapping new column names to existing column
+            names (str) or callables. Callables are not currently processed
+            by the agent data collector but are allowed for API compatibility.
+            Example: {"agent_wealth": "wealth", "age_in_years": "age"}
         trigger : Callable[[Any], bool] | None
             A function(model) -> bool that determines whether to collect data.
         reset_memory : bool
@@ -105,6 +109,18 @@ def __init__(
         max_worker : int
             Maximum number of worker threads used for flushing collected data asynchronously
         """
+        if agent_reporters:
+            for key, value in agent_reporters.items():
+                if not isinstance(key, str):
+                    raise TypeError(
+                        f"Agent reporter keys must be strings (the final column name), not {type(key).__name__!r}."
+                    )
+                if not (isinstance(value, str) or callable(value)):
+                    raise TypeError(
+                        f"Agent reporter for '{key}' must be either a string (the source column name) "
+                        f"or a callable (function taking an agent and returning a value), not {type(value).__name__!r}."
+                    )
+
         super().__init__(
             model=model,
             model_reporters=model_reporters,
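The new validation accepts or rejects reporter dicts as follows (a sketch; `my_model` is assumed and the reporter values are arbitrary):

    # Illustrates the TypeError checks above; my_model is assumed.
    DataCollector(model=my_model, agent_reporters={"agent_wealth": "wealth"})     # ok: str source column
    DataCollector(model=my_model, agent_reporters={"w2": lambda a: a["wealth"]})  # ok: callable
    DataCollector(model=my_model, agent_reporters={42: "wealth"})                 # TypeError: key is not a str
    DataCollector(model=my_model, agent_reporters={"bad": 3.14})                  # TypeError: value is neither str nor callable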
@@ -174,25 +190,71 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
         """
         Collect agent-level data using the agent_reporters.

-        Constructs a LazyFrame with one column per reporter and
-        includes `step` and `seed` metadata. Appends it to internal storage.
+        This method iterates through all AgentSets in the model, selects the
+        `unique_id` and the requested reporter columns from each AgentSet's
+        DataFrame, adds an `agent_type` column, and concatenates them
+        into a single "long" format LazyFrame.
         """
-        agent_data_dict = {}
-        for col_name, reporter in self._agent_reporters.items():
-            if isinstance(reporter, str):
-                for k, v in self._model.sets[reporter].items():
-                    agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
-            else:
-                agent_data_dict[col_name] = reporter(self._model)
-        agent_lazy_frame = pl.LazyFrame(agent_data_dict)
-        agent_lazy_frame = agent_lazy_frame.with_columns(
+        all_agent_frames = []
+        reporter_map = self._agent_reporters
+
+        try:
+            agent_sets_list = self._model.sets._agentsets
+        except AttributeError:
+            logging.error(
+                "DataCollector could not find '_agentsets' attribute on model.sets. "
+                "Agent data collection will be skipped."
+            )
+            return
+
+        for agent_set in agent_sets_list:
+            if not hasattr(agent_set, "df"):
+                logging.warning(
+                    f"AgentSet {agent_set.__class__.__name__} has no 'df' attribute. Skipping."
+                )
+                continue
+
+            agent_df = agent_set.df.lazy()
+            agent_type = agent_set.__class__.__name__
+            available_cols = agent_df.columns
+
+            if "unique_id" not in available_cols:
+                logging.warning(
+                    f"AgentSet {agent_type} 'df' has no 'unique_id' column. Skipping."
+                )
+                continue
+
+            cols_to_select = [pl.col("unique_id")]
+
+            for final_name, source_col in reporter_map.items():
+                if source_col in available_cols:
+                    ## Add the column, aliasing it if the key is different
+                    cols_to_select.append(pl.col(source_col).alias(final_name))
+
+            ## Only proceed if we have more than just unique_id
+            if len(cols_to_select) > 1:
+                set_frame = agent_df.select(cols_to_select)
+                ## Add the agent_type column
+                set_frame = set_frame.with_columns(
+                    pl.lit(agent_type).alias("agent_type")
+                )
+                all_agent_frames.append(set_frame)
+
+        if not all_agent_frames:
+            return
+
+        ## Combine all agent set DataFrames into one
+        final_agent_frame = pl.concat(all_agent_frames, how="diagonal_relaxed")
+
+        ## Add metadata and append
+        final_agent_frame = final_agent_frame.with_columns(
             [
                 pl.lit(current_model_step).alias("step"),
                 pl.lit(str(self.seed)).alias("seed"),
                 pl.lit(batch_id).alias("batch"),
             ]
         )
-        self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
+        self._frames.append(("agent", current_model_step, batch_id, final_agent_frame))

     @property
     def data(self) -> dict[str, pl.DataFrame]:
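The select/alias plus diagonal-concat pattern above is easiest to see in isolation. A self-contained polars sketch of the same construction; the agent sets, columns, and reporter dicts here are invented for illustration:

    import polars as pl

    # Two toy agent-set frames with different schemas.
    wolves = pl.LazyFrame({"unique_id": [1, 2], "energy": [5, 7]})
    sheep = pl.LazyFrame({"unique_id": [3], "wool": [2.5]})

    frames = []
    for lf, agent_type, reporters in [
        (wolves, "Wolf", {"agent_energy": "energy"}),
        (sheep, "Sheep", {"agent_wool": "wool"}),
    ]:
        # unique_id plus each reporter column, aliased to its output name.
        cols = [pl.col("unique_id")] + [
            pl.col(src).alias(dst) for dst, src in reporters.items()
        ]
        frames.append(
            lf.select(cols).with_columns(pl.lit(agent_type).alias("agent_type"))
        )

    # diagonal_relaxed fills columns missing from a frame with nulls
    # and relaxes dtypes, yielding one "long" frame across agent types.
    long_frame = pl.concat(frames, how="diagonal_relaxed")
    print(long_frame.collect())
    # 3 rows x 4 cols: unique_id, agent_energy, agent_type, agent_wool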
@@ -461,13 +523,20 @@ def _validate_reporter_table_columns(
         If any expected columns are missing from the table.
         """
         expected_columns = set()
+
+        ## Add columns required for the new long agent format
+        if table_name == "agent_data":
+            expected_columns.add("unique_id")
+            expected_columns.add("agent_type")
+
+        ## Add all keys from the reporter dict
         for col_name, required_column in reporter.items():
-            if isinstance(required_column, str):
-                for k, v in self._model.sets[required_column].items():
-                    expected_columns.add(
-                        (col_name + "_" + str(k.__class__.__name__)).lower()
-                    )
+            if table_name == "agent_data":
+                if isinstance(required_column, str):
+                    expected_columns.add(col_name.lower())
+                ## Callables are not supported for agents
+            else:
+                ## For model, all reporters are callable
+                expected_columns.add(col_name.lower())

         query = f"""
@@ -484,10 +553,7 @@ def _validate_reporter_table_columns(

         existing_columns = {row[0] for row in result}
         missing_columns = expected_columns - existing_columns
-        required_columns = {
-            "step": "Integer",
-            "seed": "Varchar",
-        }
+        required_columns = {"step": "Integer", "seed": "Varchar", "batch": "Integer"}

         missing_required = {
             col: col_type
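A small sketch of what the column validation above ends up expecting for the agent table; the reporter dict here is hypothetical:

    def expected_agent_columns(reporters: dict) -> set[str]:
        # Mirrors the loop above for table_name == "agent_data".
        expected = {"unique_id", "agent_type"}
        for col_name, source in reporters.items():
            if isinstance(source, str):  # callables are skipped for agent tables
                expected.add(col_name.lower())
        return expected

    print(expected_agent_columns({"agent_wealth": "wealth", "score": lambda a: a}))
    # {'unique_id', 'agent_type', 'agent_wealth'} -- step, seed, and the new
    # batch column are checked separately via required_columns.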