Skip to content

Commit 559ae96

Browse files
Python entry-point for CustomLLM subclasses (#15881)
* load entrypoints * mock loading entry-point in pyproject.toml * simpler group name * create CustomLLM subclass instance after load
1 parent e6a7cae commit 559ae96

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

litellm/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from dataclasses import dataclass, field
3737
from functools import lru_cache, wraps
3838
from importlib import resources
39+
from importlib.metadata import entry_points
3940
from inspect import iscoroutine
4041
from io import StringIO
4142
from os.path import abspath, dirname, join
@@ -385,10 +386,22 @@ def print_verbose(
385386

386387
####### CLIENT ###################
387388
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
389+
def load_custom_provider_entrypoints():
390+
found_entry_points = tuple(entry_points().select(group="litellm")) # type: ignore
391+
for entry_point in found_entry_points:
392+
# types are ignored because of circular dependency issues importing CustomLLM and CustomLLMItem
393+
HandlerClass = entry_point.load()
394+
handler = HandlerClass()
395+
provider = {"provider": entry_point.name, "custom_handler": handler}
396+
litellm.custom_provider_map.append(provider) # type: ignore
397+
398+
388399
def custom_llm_setup():
389400
"""
390401
Add custom_llm provider to provider list
391402
"""
403+
load_custom_provider_entrypoints()
404+
392405
for custom_llm in litellm.custom_provider_map:
393406
if custom_llm["provider"] not in litellm.provider_list:
394407
litellm.provider_list.append(custom_llm["provider"])

tests/local_testing/test_custom_llm.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import openai
1212
import pytest
13+
from pytest_mock import MockerFixture
1314

1415
sys.path.insert(
1516
0, os.path.abspath("../..")
@@ -536,4 +537,47 @@ async def test_simple_aembedding():
536537
"object": "embedding",
537538
"embedding": [0.1, 0.2, 0.3],
538539
"index": 1,
539-
}
540+
}
541+
542+
543+
def test_custom_llm_provider_entrypoint(mocker: MockerFixture):
544+
# This test mocks the use of entry-points in pyproject.toml:
545+
# [project.entry-point.litellm]
546+
# custom_llm = <module>:MyCustomLLM
547+
# another-custom-llm = <module>:AnotherCustomLLM
548+
549+
from litellm.utils import custom_llm_setup
550+
551+
class AnotherCustomLLM(CustomLLM):
552+
pass
553+
554+
providers = {
555+
"custom_llm": MyCustomLLM,
556+
"another-custom-llm": AnotherCustomLLM
557+
}
558+
559+
def load(self):
560+
return providers[self.name]
561+
562+
mocker.patch("importlib.metadata.EntryPoint.load", load)
563+
from importlib.metadata import EntryPoints, EntryPoint
564+
565+
entry_points = EntryPoints([
566+
EntryPoint(group="litellm", name="custom_llm", value="package.module:MyCustomLLM"),
567+
EntryPoint(group="litellm", name="another-custom-llm", value="package.module:AnotherCustomLLM"),
568+
])
569+
mocked = mocker.patch("litellm.utils.entry_points")
570+
mocked.return_value = entry_points
571+
572+
assert litellm.custom_provider_map == []
573+
assert litellm._custom_providers == []
574+
575+
custom_llm_setup()
576+
577+
assert litellm._custom_providers == ['custom_llm', 'another-custom-llm']
578+
579+
assert litellm.custom_provider_map[0]["provider"] == "custom_llm"
580+
assert isinstance(litellm.custom_provider_map[0]["custom_handler"], CustomLLM)
581+
582+
assert litellm.custom_provider_map[1]["provider"] == "another-custom-llm"
583+
assert isinstance(litellm.custom_provider_map[1]["custom_handler"], AnotherCustomLLM)

0 commit comments

Comments
 (0)