Skip to content

Commit eac6d47

Browse files
authored
add a class for managing tools (#3)
1 parent 25e44e2 commit eac6d47

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

mesa_llm/tools/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tool_manager import ToolManager
2+
3+
__all__ = ["ToolManager"]

mesa_llm/tools/inbuilt_tools.py

Whitespace-only changes.

mesa_llm/tools/tool_manager.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import inspect
2+
from collections.abc import Callable
3+
4+
"""
5+
The user can use insntaces of ToolManager to regsiter functions as tools through the decorator.
6+
The user can also use the ToolManager instance to get the schema of the tools, call a tool with validated arguments, and check if a tool is registered.
7+
Moreover, the user can group like tools together by creating a new ToolManager instance and registering the tools to it.
8+
So if agent A requires tools A1, A2, and A3, and agent B requires tools B1, B2, and B3, the user can create two ToolManager instances: tool_manager_A and tool_manager_B.
9+
"""
10+
11+
class ToolManager:
12+
def __init__(self):
13+
self.tools: dict[str, Callable] = {}
14+
15+
def register(self, fn: Callable):
16+
"""Register a tool function by name"""
17+
name = fn.__name__
18+
self.tools[name] = fn #storing the name & function pair as a dicitonary
19+
20+
def get_schema(self) -> list[dict]:
21+
"""Return schema in the liteLLM format"""
22+
#we need to convert the function signature from python to a JSON schema
23+
py_to_json_type = {
24+
str: "string",
25+
int: "integer",
26+
float: "number",
27+
bool: "boolean",
28+
list: "array",
29+
dict: "object"
30+
}
31+
32+
schema = []
33+
for name, fn in self.tools.items():
34+
sig = inspect.signature(fn)
35+
properties = {}
36+
required = []
37+
38+
for param in sig.parameters.values():
39+
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
40+
# skip *args and **kwargs
41+
continue
42+
param_schema = {
43+
"description": f"{param.name} parameter"
44+
}
45+
46+
# If type annotation is available
47+
if param.annotation != inspect.Parameter.empty:
48+
annotation = param.annotation
49+
50+
json_type = py_to_json_type.get(annotation)
51+
52+
if json_type:
53+
param_schema["type"] = json_type
54+
else:
55+
# fallback: allow any type
56+
param_schema["type"] = ["string", "number", "boolean", "object", "array", "null"]
57+
else:
58+
# No annotation so fallback
59+
param_schema["type"] = ["string", "number", "boolean", "object", "array", "null"]
60+
61+
properties[param.name] = param_schema
62+
63+
if param.default == inspect.Parameter.empty:
64+
required.append(param.name)
65+
66+
schema.append({
67+
"type": "function",
68+
"function": {
69+
"name": name,
70+
"description": fn.__doc__ or "",
71+
"parameters": {
72+
"type": "object",
73+
"properties": properties,
74+
"required": required
75+
}
76+
}
77+
})
78+
return schema #incase the user wants to change the something like say the parameter description, they will have to get the schema and edit it manually
79+
80+
def call(self, name: str, arguments: dict) -> str:
81+
"""Call a registered tool with validated args"""
82+
if name not in self.tools:
83+
raise ValueError(f"Tool '{name}' not found")
84+
return self.tools[name](**arguments)
85+
86+
def has_tool(self, name: str) -> bool:
87+
return name in self.tools
88+
89+
90+

0 commit comments

Comments
 (0)