Skip to content

Commit b0dd7a6

Browse files
committed
sys_prompt
1 parent bbc3cc7 commit b0dd7a6

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

mesa_llm/llm_agent.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33

44
class LLMAgent:
    """Agent-facing wrapper that owns a configured ModuleLLM instance.

    The underlying ``ModuleLLM`` is exposed as ``self.llm`` and can be
    reconfigured after construction via :meth:`set_model` and
    :meth:`set_system_prompt`.
    """

    def __init__(self, api_key: str, model: str = "openai/gpt-4o", system_prompt: str | None = None):
        """Initialize the LLMAgent with a ModuleLLM instance.

        Args:
            api_key: The API key for the LLM provider.
            model: The model to use, in the format ``{provider}/{model}``.
            system_prompt: Optional system prompt forwarded to the LLM.
        """
        self.llm = ModuleLLM(api_key=api_key, model=model, system_prompt=system_prompt)

    def set_model(self, api_key: str, model: str = "openai/gpt-4o") -> None:
        """Set the model of the LLM, preserving the current system prompt.

        A fresh ``ModuleLLM`` is built on purpose: its constructor exports the
        provider API key into the environment, which must be redone when the
        provider (the part of ``model`` before ``/``) changes.
        """
        self.llm = ModuleLLM(api_key=api_key, model=model, system_prompt=self.llm.system_prompt)

    def set_system_prompt(self, system_prompt: str) -> None:
        """Set the system prompt for the LLM.

        Assigns in place instead of rebuilding ``ModuleLLM``: the prompt is a
        plain attribute read by ``generate``, so reconstruction (and the
        environment-variable re-export it triggers) is unnecessary.
        """
        self.llm.system_prompt = system_prompt

mesa_llm/module_llm.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,30 @@ class ModuleLLM:
1010
Note : Currently supports OpenAI, Anthropic, xAI, Huggingface, Ollama, OpenRouter, NovitaAI
1111
"""
1212

13-
def __init__(self, api_key: str, model: str, system_prompt: str | None = None):
    """
    Initialize the LLM module.

    Args:
        api_key: The API key for the LLM provider
        model: The model to use for the LLM in the format of {provider}/{model}
        system_prompt: The system prompt to use for the LLM (``None`` means
            no system message is sent)

    Side effect:
        Exports the key as the ``{PROVIDER}_API_KEY`` environment variable
        (e.g. ``OPENAI_API_KEY`` for ``openai/gpt-4o``), which the completion
        backend presumably reads for authentication — confirm against the
        backend's docs.
    """
    self.api_key = api_key
    self.model = model
    self.system_prompt = system_prompt
    # Provider is the prefix before the first "/" ("openai/gpt-4o" -> "OPENAI").
    provider = self.model.split("/")[0].upper()
    os.environ[f"{provider}_API_KEY"] = self.api_key
2527

def generate(self, prompt: str) -> str:
    """
    Generate a response from the LLM.

    Args:
        prompt: The user prompt to send to the model.

    Returns:
        The raw response returned by ``completion``.
        NOTE(review): the ``-> str`` annotation looks wrong — ``completion``
        returns a response object, not a plain string; confirm what callers
        expect before tightening this.
    """
    # Prepend the configured system prompt when one is set; a falsy value
    # (None or empty string) sends the user message alone.
    if self.system_prompt:
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
    else:
        messages = [{"role": "user", "content": prompt}]
    response = completion(model=self.model, messages=messages)
    return response
4038

4139

0 commit comments

Comments
 (0)