1010
1111import openai
1212import pytest
13+ from pytest_mock import MockerFixture
1314
1415sys .path .insert (
1516 0 , os .path .abspath ("../.." )
@@ -536,4 +537,47 @@ async def test_simple_aembedding():
536537 "object" : "embedding" ,
537538 "embedding" : [0.1 , 0.2 , 0.3 ],
538539 "index" : 1 ,
539- }
540+ }
541+
542+
543+ def test_custom_llm_provider_entrypoint (mocker : MockerFixture ):
544+ # This test mocks the use of entry-points in pyproject.toml:
545+ # [project.entry-point.litellm]
546+ # custom_llm = <module>:MyCustomLLM
547+ # another-custom-llm = <module>:AnotherCustomLLM
548+
549+ from litellm .utils import custom_llm_setup
550+
551+ class AnotherCustomLLM (CustomLLM ):
552+ pass
553+
554+ providers = {
555+ "custom_llm" : MyCustomLLM ,
556+ "another-custom-llm" : AnotherCustomLLM
557+ }
558+
559+ def load (self ):
560+ return providers [self .name ]
561+
562+ mocker .patch ("importlib.metadata.EntryPoint.load" , load )
563+ from importlib .metadata import EntryPoints , EntryPoint
564+
565+ entry_points = EntryPoints ([
566+ EntryPoint (group = "litellm" , name = "custom_llm" , value = "package.module:MyCustomLLM" ),
567+ EntryPoint (group = "litellm" , name = "another-custom-llm" , value = "package.module:AnotherCustomLLM" ),
568+ ])
569+ mocked = mocker .patch ("litellm.utils.entry_points" )
570+ mocked .return_value = entry_points
571+
572+ assert litellm .custom_provider_map == []
573+ assert litellm ._custom_providers == []
574+
575+ custom_llm_setup ()
576+
577+ assert litellm ._custom_providers == ['custom_llm' , 'another-custom-llm' ]
578+
579+ assert litellm .custom_provider_map [0 ]["provider" ] == "custom_llm"
580+ assert isinstance (litellm .custom_provider_map [0 ]["custom_handler" ], CustomLLM )
581+
582+ assert litellm .custom_provider_map [1 ]["provider" ] == "another-custom-llm"
583+ assert isinstance (litellm .custom_provider_map [1 ]["custom_handler" ], AnotherCustomLLM )
0 commit comments