|
| 1 | +import inspect |
| 2 | +from collections.abc import Callable |
| 3 | + |
| 4 | +""" |
| 5 | + The user can use insntaces of ToolManager to regsiter functions as tools through the decorator. |
| 6 | + The user can also use the ToolManager instance to get the schema of the tools, call a tool with validated arguments, and check if a tool is registered. |
| 7 | + Moreover, the user can group like tools together by creating a new ToolManager instance and registering the tools to it. |
| 8 | + So if agent A requires tools A1, A2, and A3, and agent B requires tools B1, B2, and B3, the user can create two ToolManager instances: tool_manager_A and tool_manager_B. |
| 9 | +""" |
| 10 | + |
| 11 | +class ToolManager: |
| 12 | + def __init__(self): |
| 13 | + self.tools: dict[str, Callable] = {} |
| 14 | + |
| 15 | + def register(self, fn: Callable): |
| 16 | + """Register a tool function by name""" |
| 17 | + name = fn.__name__ |
| 18 | + self.tools[name] = fn #storing the name & function pair as a dicitonary |
| 19 | + |
| 20 | + def get_schema(self) -> list[dict]: |
| 21 | + """Return schema in the liteLLM format""" |
| 22 | + #we need to convert the function signature from python to a JSON schema |
| 23 | + py_to_json_type = { |
| 24 | + str: "string", |
| 25 | + int: "integer", |
| 26 | + float: "number", |
| 27 | + bool: "boolean", |
| 28 | + list: "array", |
| 29 | + dict: "object" |
| 30 | + } |
| 31 | + |
| 32 | + schema = [] |
| 33 | + for name, fn in self.tools.items(): |
| 34 | + sig = inspect.signature(fn) |
| 35 | + properties = {} |
| 36 | + required = [] |
| 37 | + |
| 38 | + for param in sig.parameters.values(): |
| 39 | + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): |
| 40 | + # skip *args and **kwargs |
| 41 | + continue |
| 42 | + param_schema = { |
| 43 | + "description": f"{param.name} parameter" |
| 44 | + } |
| 45 | + |
| 46 | + # If type annotation is available |
| 47 | + if param.annotation != inspect.Parameter.empty: |
| 48 | + annotation = param.annotation |
| 49 | + |
| 50 | + json_type = py_to_json_type.get(annotation) |
| 51 | + |
| 52 | + if json_type: |
| 53 | + param_schema["type"] = json_type |
| 54 | + else: |
| 55 | + # fallback: allow any type |
| 56 | + param_schema["type"] = ["string", "number", "boolean", "object", "array", "null"] |
| 57 | + else: |
| 58 | + # No annotation so fallback |
| 59 | + param_schema["type"] = ["string", "number", "boolean", "object", "array", "null"] |
| 60 | + |
| 61 | + properties[param.name] = param_schema |
| 62 | + |
| 63 | + if param.default == inspect.Parameter.empty: |
| 64 | + required.append(param.name) |
| 65 | + |
| 66 | + schema.append({ |
| 67 | + "type": "function", |
| 68 | + "function": { |
| 69 | + "name": name, |
| 70 | + "description": fn.__doc__ or "", |
| 71 | + "parameters": { |
| 72 | + "type": "object", |
| 73 | + "properties": properties, |
| 74 | + "required": required |
| 75 | + } |
| 76 | + } |
| 77 | + }) |
| 78 | + return schema #incase the user wants to change the something like say the parameter description, they will have to get the schema and edit it manually |
| 79 | + |
| 80 | + def call(self, name: str, arguments: dict) -> str: |
| 81 | + """Call a registered tool with validated args""" |
| 82 | + if name not in self.tools: |
| 83 | + raise ValueError(f"Tool '{name}' not found") |
| 84 | + return self.tools[name](**arguments) |
| 85 | + |
| 86 | + def has_tool(self, name: str) -> bool: |
| 87 | + return name in self.tools |
| 88 | + |
| 89 | + |
| 90 | + |
0 commit comments