Skip to content

Commit 11d6517

Browse files
authored
Merge pull request #124 from 00x808080/feature/add-model-constants-file
Add model constants file and standardize model name usage
2 parents 5491ec8 + d0a8753 commit 11d6517

File tree

23 files changed

+96
-71
lines changed

23 files changed

+96
-71
lines changed

skllm/llm/anthropic/completion.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
from typing import Dict, List, Optional
22
from skllm.llm.anthropic.credentials import set_credentials
33
from skllm.utils import retry
4+
from model_constants import ANTHROPIC_CLAUDE_MODEL
45

56
@retry(max_retries=3)
67
def get_chat_completion(
78
messages: List[Dict],
89
key: str,
9-
model: str = "claude-3-haiku-20240307",
10+
model: str = ANTHROPIC_CLAUDE_MODEL,
1011
max_tokens: int = 1000,
1112
temperature: float = 0.0,
1213
system: Optional[str] = None,

skllm/llm/gpt/clients/openai/completion.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,14 +5,15 @@
55
set_credentials,
66
)
77
from skllm.utils import retry
8+
from model_constants import OPENAI_GPT_MODEL
89

910

1011
@retry(max_retries=3)
1112
def get_chat_completion(
1213
messages: dict,
1314
key: str,
1415
org: str,
15-
model: str = "gpt-3.5-turbo",
16+
model: str = OPENAI_GPT_MODEL,
1617
api="openai",
1718
json_response=False,
1819
):
@@ -27,7 +28,7 @@ def get_chat_completion(
2728
org : str
2829
The OPEN AI organization ID to use.
2930
model : str, optional
30-
The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
31+
The OPEN AI model to use.
3132
max_retries : int, optional
3233
The maximum number of retries to use. Defaults to 3.
3334
api : str

skllm/llm/gpt/clients/openai/embedding.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,15 @@
22
from skllm.utils import retry
33
import openai
44
from openai import OpenAI
5+
from model_constants import OPENAI_EMBEDDING_MODEL
56

67

78
@retry(max_retries=3)
89
def get_embedding(
910
text: str,
1011
key: str,
1112
org: str,
12-
model: str = "text-embedding-ada-002",
13+
model: str = OPENAI_EMBEDDING_MODEL,
1314
api: str = "openai"
1415
):
1516
"""
@@ -24,7 +25,7 @@ def get_embedding(
2425
org : str
2526
The OPEN AI organization ID to use.
2627
model : str, optional
27-
The model to use. Defaults to "text-embedding-ada-002".
28+
The model to use.
2829
max_retries : int, optional
2930
The maximum number of retries to use. Defaults to 3.
3031
api: str, optional

skllm/llm/gpt/completion.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,14 @@
77
)
88
from skllm.llm.gpt.utils import split_to_api_and_model
99
from skllm.config import SKLLMConfig as _Config
10+
from model_constants import OPENAI_GPT_MODEL
1011

1112

1213
def get_chat_completion(
1314
messages: dict,
1415
openai_key: str = None,
1516
openai_org: str = None,
16-
model: str = "gpt-3.5-turbo",
17+
model: str = OPENAI_GPT_MODEL,
1718
json_response: bool = False,
1819
):
1920
"""Gets a chat completion from the OpenAI compatible API."""

skllm/llm/gpt/embedding.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
from skllm.llm.gpt.clients.openai.embedding import get_embedding as _oai_get_embedding
22
from skllm.llm.gpt.utils import split_to_api_and_model
3+
from model_constants import OPENAI_EMBEDDING_MODEL
4+
35

46
def get_embedding(
57
text: str,
68
key: str,
79
org: str,
8-
model: str = "text-embedding-ada-002",
10+
model: str = OPENAI_EMBEDDING_MODEL,
911
):
1012
"""
1113
Encodes a string and return the embedding for a string.
@@ -19,7 +21,7 @@ def get_embedding(
1921
org : str
2022
The OPEN AI organization ID to use.
2123
model : str, optional
22-
The model to use. Defaults to "text-embedding-ada-002".
24+
The model to use.
2325
2426
Returns
2527
-------

skllm/llm/gpt/mixin.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
from tqdm import tqdm
2323
import json
24+
from model_constants import OPENAI_GPT_TUNABLE_MODEL
2425

2526

2627
def construct_message(role: str, content: str) -> dict:
@@ -219,24 +220,16 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:
219220

220221
# for now this works only with OpenAI
221222
class GPTTunableMixin(BaseTunableMixin):
222-
_supported_tunable_models = [
223-
"gpt-3.5-turbo-0125",
224-
"gpt-3.5-turbo",
225-
"gpt-4o-mini-2024-07-18",
226-
"gpt-4o-mini",
227-
]
228-
229223
def _build_label(self, label: str):
230224
return json.dumps({"label": label})
231225

232226
def _set_hyperparameters(self, base_model: str, n_epochs: int, custom_suffix: str):
233-
self.base_model = base_model
227+
self.base_model = OPENAI_GPT_TUNABLE_MODEL
234228
self.n_epochs = n_epochs
235229
self.custom_suffix = custom_suffix
236230
if base_model not in self._supported_tunable_models:
237231
raise ValueError(
238232
f"Model {base_model} is not supported. Supported models are"
239-
f" {self._supported_tunable_models}"
240233
)
241234

242235
def _tune(self, X, y):

skllm/model_constants.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
# OpenAI models
2+
OPENAI_GPT_MODEL = "gpt-3.5-turbo"
3+
OPENAI_GPT_TUNABLE_MODEL = "gpt-3.5-turbo-0613"
4+
OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"
5+
6+
# Anthropic (Claude) models
7+
ANTHROPIC_CLAUDE_MODEL = "claude-3-haiku-20240307"
8+
9+
# Vertex AI models
10+
VERTEX_DEFAULT_MODEL = "text-bison@002"

skllm/models/anthropic/classification/few_shot.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,15 @@
99
from skllm.models._base.vectorizer import BaseVectorizer
1010
from skllm.memory.base import IndexConstructor
1111
from typing import Optional
12+
from model_constants import ANTHROPIC_CLAUDE_MODEL, OPENAI_EMBEDDING_MODEL
1213

1314

1415
class FewShotClaudeClassifier(BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin):
1516
"""Few-shot text classifier using Anthropic's Claude API for single-label classification tasks."""
1617

1718
def __init__(
1819
self,
19-
model: str = "claude-3-haiku-20240307",
20+
model: str = ANTHROPIC_CLAUDE_MODEL,
2021
default_label: str = "Random",
2122
prompt_template: Optional[str] = None,
2223
key: Optional[str] = None,
@@ -28,7 +29,7 @@ def __init__(
2829
Parameters
2930
----------
3031
model : str, optional
31-
model to use, by default "claude-3-haiku-20240307"
32+
model to use
3233
default_label : str, optional
3334
default label for failed prediction; if "Random" -> selects randomly based on class frequencies
3435
prompt_template : Optional[str], optional
@@ -52,7 +53,7 @@ class MultiLabelFewShotClaudeClassifier(
5253

5354
def __init__(
5455
self,
55-
model: str = "claude-3-haiku-20240307",
56+
model: str = ANTHROPIC_CLAUDE_MODEL,
5657
default_label: str = "Random",
5758
max_labels: Optional[int] = 5,
5859
prompt_template: Optional[str] = None,
@@ -65,7 +66,7 @@ def __init__(
6566
Parameters
6667
----------
6768
model : str, optional
68-
model to use, by default "claude-3-haiku-20240307"
69+
model to use
6970
default_label : str, optional
7071
default label for failed prediction; if "Random" -> selects randomly based on class frequencies
7172
max_labels : Optional[int], optional
@@ -95,7 +96,7 @@ class DynamicFewShotClaudeClassifier(
9596

9697
def __init__(
9798
self,
98-
model: str = "claude-3-haiku-20240307",
99+
model: str = ANTHROPIC_CLAUDE_MODEL,
99100
default_label: str = "Random",
100101
prompt_template: Optional[str] = None,
101102
key: Optional[str] = None,
@@ -112,7 +113,7 @@ def __init__(
112113
Parameters
113114
----------
114115
model : str, optional
115-
model to use, by default "claude-3-haiku-20240307"
116+
model to use
116117
default_label : str, optional
117118
default label for failed prediction; if "Random" -> selects randomly based on class frequencies
118119
prompt_template : Optional[str], optional
@@ -129,7 +130,7 @@ def __init__(
129130
metric used for similarity search, by default "euclidean"
130131
"""
131132
if vectorizer is None:
132-
vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key)
133+
vectorizer = GPTVectorizer(model=OPENAI_EMBEDDING_MODEL, key=key)
133134
super().__init__(
134135
model=model,
135136
default_label=default_label,

skllm/models/anthropic/classification/zero_shot.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
)
77
from skllm.llm.anthropic.mixin import ClaudeClassifierMixin as _ClaudeClassifierMixin
88
from typing import Optional
9+
from model_constants import ANTHROPIC_CLAUDE_MODEL
910

1011

1112
class ZeroShotClaudeClassifier(
@@ -15,7 +16,7 @@ class ZeroShotClaudeClassifier(
1516

1617
def __init__(
1718
self,
18-
model: str = "claude-3-haiku-20240307",
19+
model: str = ANTHROPIC_CLAUDE_MODEL,
1920
default_label: str = "Random",
2021
prompt_template: Optional[str] = None,
2122
key: Optional[str] = None,
@@ -27,7 +28,7 @@ def __init__(
2728
Parameters
2829
----------
2930
model : str, optional
30-
Model to use, by default "claude-3-haiku-20240307".
31+
Model to use
3132
default_label : str, optional
3233
Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
3334
prompt_template : Optional[str], optional
@@ -51,7 +52,7 @@ class CoTClaudeClassifier(
5152

5253
def __init__(
5354
self,
54-
model: str = "claude-3-haiku-20240307",
55+
model: str = ANTHROPIC_CLAUDE_MODEL,
5556
default_label: str = "Random",
5657
prompt_template: Optional[str] = None,
5758
key: Optional[str] = None,
@@ -63,7 +64,7 @@ def __init__(
6364
Parameters
6465
----------
6566
model : str, optional
66-
Model to use, by default "claude-3-haiku-20240307".
67+
Model to use.
6768
default_label : str, optional
6869
Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
6970
prompt_template : Optional[str], optional
@@ -87,7 +88,7 @@ class MultiLabelZeroShotClaudeClassifier(
8788

8889
def __init__(
8990
self,
90-
model: str = "claude-3-haiku-20240307",
91+
model: str = ANTHROPIC_CLAUDE_MODEL,
9192
default_label: str = "Random",
9293
max_labels: Optional[int] = 5,
9394
prompt_template: Optional[str] = None,
@@ -100,7 +101,7 @@ def __init__(
100101
Parameters
101102
----------
102103
model : str, optional
103-
Model to use, by default "claude-3-haiku-20240307".
104+
Model to use.
104105
default_label : str, optional
105106
Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
106107
max_labels : Optional[int], optional

skllm/models/anthropic/tagging/ner.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
22
from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin
33
from typing import Optional, Dict
4+
from model_constants import ANTHROPIC_CLAUDE_MODEL
45

56

67
class AnthropicExplainableNER(_ExplainableNER, _ClaudeTextCompletionMixin):
@@ -11,7 +12,7 @@ def __init__(
1112
entities: Dict[str, str],
1213
display_predictions: bool = False,
1314
sparse_output: bool = True,
14-
model: str = "claude-3-haiku-20240307",
15+
model: str = ANTHROPIC_CLAUDE_MODEL,
1516
key: Optional[str] = None,
1617
num_workers: int = 1,
1718
) -> None:
@@ -27,7 +28,7 @@ def __init__(
2728
sparse_output : bool, optional
2829
whether to generate a sparse representation of the predictions, by default True
2930
model : str, optional
30-
model to use, by default "claude-3-haiku-20240307"
31+
model to use
3132
key : Optional[str], optional
3233
estimator-specific API key; if None, retrieved from the global config
3334
num_workers : int, optional

0 commit comments

Comments
 (0)