diff --git a/.gitignore b/.gitignore index cae9c06..084fb39 100644 --- a/.gitignore +++ b/.gitignore @@ -339,4 +339,5 @@ pip-log.txt #GSoC documentation final_architecture.png -GSoC_2025.md \ No newline at end of file +GSoC_2025.md +/architectures \ No newline at end of file diff --git a/mesa_llm/memory/episodic_memory.py b/mesa_llm/memory/episodic_memory.py index 142defb..0019d19 100644 --- a/mesa_llm/memory/episodic_memory.py +++ b/mesa_llm/memory/episodic_memory.py @@ -131,3 +131,14 @@ def get_communication_history(self) -> str: if "message" in entry.content ] ) + + def process_step(self, pre_step: bool = False): + """ + Process the step of the agent : + - Add the new entry to the memory + - Display the new entry + """ + if pre_step: + self.add_to_memory(type="observation", content=self.step_content) + self.step_content = {} + return diff --git a/mesa_llm/memory/memory.py b/mesa_llm/memory/memory.py index a8683de..3d25bbb 100644 --- a/mesa_llm/memory/memory.py +++ b/mesa_llm/memory/memory.py @@ -111,6 +111,15 @@ def get_communication_history(self) -> str: Get the communication history in a format that can be used for reasoning """ + @abstractmethod + def process_step(self, pre_step: bool = False): + r""" + A function that is called before and after the step of the agent is called. + It is implemented to ensure that the memory is up to date when the agent is starting a new step. + + /!\ If you consider that you do not need this function, you can write "pass" in its implementation. + """ + def add_to_memory(self, type: str, content: dict): """ Add a new entry to the memory diff --git a/mesa_llm/memory/st_lt_memory.py b/mesa_llm/memory/st_lt_memory.py index 385c342..4c4fcd0 100644 --- a/mesa_llm/memory/st_lt_memory.py +++ b/mesa_llm/memory/st_lt_memory.py @@ -101,8 +101,8 @@ def process_step(self, pre_step: bool = False): return elif not self.short_term_memory[-1].content.get("step", None): - pre_step = self.short_term_memory.pop() - self.step_content.update(pre_step.content) + pre_step_entry = self.short_term_memory.pop() + self.step_content.update(pre_step_entry.content) new_entry = MemoryEntry( agent=self.agent, content=self.step_content, @@ -112,21 +112,21 @@ def process_step(self, pre_step: bool = False): self.short_term_memory.append(new_entry) self.step_content = {} - # Consolidate memory if the short term memory is over capacity - if ( - len(self.short_term_memory) - > self.capacity + (self.consolidation_capacity or 0) - and self.consolidation_capacity - ): - self.short_term_memory.popleft() - self._update_long_term_memory() - - elif len(self.short_term_memory) > self.capacity: - self.short_term_memory.popleft() - - # Display the new entry - if self.display: - new_entry.display() + # Consolidate memory if the short term memory is over capacity + if ( + len(self.short_term_memory) + > self.capacity + (self.consolidation_capacity or 0) + and self.consolidation_capacity + ): + self.short_term_memory.popleft() + self._update_long_term_memory() + + elif len(self.short_term_memory) > self.capacity: + self.short_term_memory.popleft() + + # Display the new entry + if self.display: + new_entry.display() def format_long_term(self) -> str: """ diff --git a/tests/test_memory/test_memory_parent.py b/tests/test_memory/test_memory_parent.py index 6848e3b..b27def8 100644 --- a/tests/test_memory/test_memory_parent.py +++ b/tests/test_memory/test_memory_parent.py @@ -48,6 +48,12 @@ def get_prompt_ready(self) -> str: def get_communication_history(self) -> str: return "" + def process_step(self, pre_step: bool = False): + """ + Mock implementation of process_step for testing purposes. + Since this is a test mock, we can use a simple pass implementation. + """ + class TestMemoryParent: """Test the Memory class"""