Skip to content

Commit 0595158

Browse files
Merge pull request #44 from vbossica/reference_provider_names
Reference providers by their enum values
2 parents 82e3780 + a9e60c2 commit 0595158

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

python_gpt_po/models/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,10 @@ class ModelProvider(Enum):
1010
OPENAI = "openai"
1111
ANTHROPIC = "anthropic"
1212
DEEPSEEK = "deepseek"
13+
14+
15+
ModelProviderList = [
16+
ModelProvider.OPENAI.value,
17+
ModelProvider.ANTHROPIC.value,
18+
ModelProvider.DEEPSEEK.value
19+
]

python_gpt_po/models/provider_clients.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from anthropic import Anthropic
88
from openai import OpenAI
99

10+
from .enums import ModelProvider
11+
1012

1113
class ProviderClients:
1214
"""Class to store API clients for various providers."""
@@ -23,11 +25,11 @@ def initialize_clients(self, api_keys: Dict[str, str]):
2325
Args:
2426
api_keys (Dict[str, str]): Dictionary of provider names to API keys
2527
"""
26-
if api_keys.get("openai"):
27-
self.openai_client = OpenAI(api_key=api_keys["openai"])
28+
if api_keys.get(ModelProvider.OPENAI.value):
29+
self.openai_client = OpenAI(api_key=api_keys[ModelProvider.OPENAI.value])
2830

29-
if api_keys.get("anthropic"):
30-
self.anthropic_client = Anthropic(api_key=api_keys["anthropic"])
31+
if api_keys.get(ModelProvider.ANTHROPIC.value):
32+
self.anthropic_client = Anthropic(api_key=api_keys[ModelProvider.ANTHROPIC.value])
3133

32-
if api_keys.get("deepseek"):
33-
self.deepseek_api_key = api_keys["deepseek"]
34+
if api_keys.get(ModelProvider.DEEPSEEK.value):
35+
self.deepseek_api_key = api_keys[ModelProvider.DEEPSEEK.value]

python_gpt_po/utils/cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sys
1010
from typing import Dict, List, Optional
1111

12-
from ..models.enums import ModelProvider
12+
from ..models.enums import ModelProvider, ModelProviderList
1313
from .helpers import get_version
1414

1515

@@ -99,7 +99,7 @@ def parse_args():
9999
# Provider settings
100100
provider_group.add_argument(
101101
"--provider",
102-
choices=["openai", "anthropic", "deepseek"],
102+
choices=ModelProviderList,
103103
help="AI provider to use (default: first provider with available API key)"
104104
)
105105
provider_group.add_argument(
@@ -248,9 +248,9 @@ def get_api_keys_from_args(args) -> Dict[str, str]:
248248
Dict[str, str]: Dictionary of provider names to API keys
249249
"""
250250
return {
251-
"openai": args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", ""),
252-
"anthropic": args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", ""),
253-
"deepseek": args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "")
251+
ModelProvider.OPENAI.value: args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", ""),
252+
ModelProvider.ANTHROPIC.value: args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", ""),
253+
ModelProvider.DEEPSEEK.value: args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "")
254254
}
255255

256256

@@ -264,7 +264,7 @@ def auto_select_provider(api_keys: Dict[str, str]) -> Optional[ModelProvider]:
264264
Returns:
265265
Optional[ModelProvider]: The auto-selected provider or None if no keys available
266266
"""
267-
for provider_name in ["openai", "anthropic", "deepseek"]:
267+
for provider_name in ModelProviderList:
268268
if api_keys.get(provider_name):
269269
provider = ModelProvider(provider_name)
270270
logging.info("Auto-selected provider: %s (based on available API key)", provider_name)

0 commit comments

Comments
 (0)