Commit 6dce35d

OKUA1 and iryna-kondr committed
tunable models
Co-authored-by: Iryna Kondrashchenko <[email protected]>
1 parent 5fbecdc commit 6dce35d

File tree

14 files changed: +221 -37 lines changed


.pre-commit-config.yaml

Lines changed: 0 additions & 15 deletions
@@ -56,18 +56,3 @@ repos:
       - id: docformatter
         additional_dependencies: [tomli]
         args: ["--in-place", "--config", "pyproject.toml"]
-  # Python tool for docstring coverage
-  - repo: https://github.com/econchick/interrogate
-    rev: 1.5.0
-    hooks:
-      - id: interrogate
-        args:
-          [
-            "--config",
-            "pyproject.toml",
-            "--generate-badge",
-            ".github/assets/badges",
-            "--badge-style",
-            "flat",
-          ]
-        pass_filenames: false

skllm/llm/base.py

Lines changed: 4 additions & 0 deletions
@@ -27,3 +27,7 @@ class BaseTunableMixin(ABC):
     @abstractmethod
     def _tune(self, X: Any, y: Any):
         pass
+
+    @abstractmethod
+    def _set_hyperparameters(self, **kwargs):
+        pass
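The diff above only declares the second abstract method on BaseTunableMixin. A minimal sketch of how a backend could satisfy this interface; the DummyTunable class, its kwargs handling, and the training_data_ attribute are illustrative assumptions, not part of the commit:

from abc import ABC, abstractmethod
from typing import Any


class BaseTunableMixin(ABC):
    # mirror of the interface shown in the diff, repeated here so the sketch is self-contained
    @abstractmethod
    def _tune(self, X: Any, y: Any):
        pass

    @abstractmethod
    def _set_hyperparameters(self, **kwargs):
        pass


class DummyTunable(BaseTunableMixin):
    """Illustrative backend: stores hyperparameters and 'tunes' by recording the data."""

    def _set_hyperparameters(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

    def _tune(self, X: Any, y: Any):
        self.training_data_ = list(zip(X, y))
        return self


backend = DummyTunable()
backend._set_hyperparameters(base_model="my-model", n_epochs=1)
backend._tune(["hello"], ["world"])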

skllm/llm/gpt/clients/openai/completion.py

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,7 @@ def get_chat_completion(
     org: str,
     model: str = "gpt-3.5-turbo",
     api="openai",
+    json_response=False,
 ):
     """Gets a chat completion from the OpenAI API.
 
@@ -38,12 +39,13 @@ def get_chat_completion(
     """
     if api == "openai":
         client = set_credentials(key, org)
-        model_dict = {"model": model, "response_format": {"type": "json_object"}}
     elif api == "azure":
         client = set_azure_credentials(key, org)
-        model_dict = {"model": model}
     else:
         raise ValueError("Invalid API")
+    model_dict = {"model": model}
+    if json_response and model in ["gpt-4-1106-preview", "gpt-3.5-turbo-1106"]:
+        model_dict["response_format"] = {"type": "json_object"}
     completion = client.chat.completions.create(
         temperature=0.0, messages=messages, **model_dict
     )
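After this change, JSON mode is only requested for the two model names listed in the diff; other models fall back to a plain completion. A hedged usage sketch of the updated helper (the placeholder credentials are assumptions and the call is not executed here; the positional order messages, key, org follows the wrapper in skllm/llm/gpt/completion.py):

from skllm.llm.gpt.clients.openai.completion import get_chat_completion

messages = [{"role": "user", "content": "Return the label as JSON."}]

completion = get_chat_completion(
    messages,
    "<OPENAI_API_KEY>",   # placeholder key
    "<OPENAI_ORG_ID>",    # placeholder organization
    model="gpt-3.5-turbo-1106",
    api="openai",
    json_response=True,   # response_format={"type": "json_object"} is added only for supported models
)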

skllm/llm/gpt/clients/openai/credentials.py

Lines changed: 8 additions & 2 deletions
@@ -3,6 +3,7 @@
 from time import sleep
 from openai import OpenAI, AzureOpenAI
 
+
 def set_credentials(key: str, org: str) -> None:
     """Set the OpenAI key and organization.
 
@@ -16,6 +17,7 @@ def set_credentials(key: str, org: str) -> None:
     client = OpenAI(api_key=key, organization=org)
     return client
 
+
 def set_azure_credentials(key: str, org: str) -> None:
     """Sets OpenAI credentials for Azure.
 
@@ -26,6 +28,10 @@ def set_azure_credentials(key: str, org: str) -> None:
     org : str
         The OpenAI (Azure) organization ID to use.
     """
-    client = AzureOpenAI(api_key=key, organization=org, api_version=_Config.get_azure_api_version(), azure_endpoint = _Config.get_azure_api_base())
+    client = AzureOpenAI(
+        api_key=key,
+        organization=org,
+        api_version=_Config.get_azure_api_version(),
+        azure_endpoint=_Config.get_azure_api_base(),
+    )
     return client
-

skllm/llm/gpt/completion.py

Lines changed: 7 additions & 1 deletion
@@ -11,6 +11,7 @@ def get_chat_completion(
     openai_key: str = None,
     openai_org: str = None,
     model: str = "gpt-3.5-turbo",
+    json_response: bool = False,
 ):
     """Gets a chat completion from the OpenAI compatible API."""
     if model.startswith("gpt4all::"):
@@ -20,5 +21,10 @@ def get_chat_completion(
     if api == "azure":
         model = model[7:]
     return _oai_get_chat_completion(
-        messages, openai_key, openai_org, model, api=api
+        messages,
+        openai_key,
+        openai_org,
+        model,
+        api=api,
+        json_response=json_response,
     )

skllm/llm/gpt/mixin.py

Lines changed: 19 additions & 1 deletion
@@ -60,6 +60,8 @@ class GPTMixin:
     A mixin class that provides OpenAI key and organization to other classes.
     """
 
+    _prefer_json_output = False
+
    def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> None:
        """
        Set the OpenAI key and organization.
@@ -132,12 +134,18 @@ def _get_chat_completion(
         for message in messages:
             msgs.append(construct_message(message["role"], message["content"]))
         completion = get_chat_completion(
-            msgs, self._get_openai_key(), self._get_openai_org(), model
+            msgs,
+            self._get_openai_key(),
+            self._get_openai_org(),
+            model,
+            json_response=self._prefer_json_output,
         )
         return completion
 
 
 class GPTClassifierMixin(GPTTextCompletionMixin, BaseClassifierMixin):
+    _prefer_json_output = True
+
     def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> Any:
         """Extracts the label from a completion.
 
@@ -205,6 +213,16 @@ class GPTTunableMixin(BaseTunableMixin):
     def _build_label(self, label: str):
         return json.dumps({"label": label})
 
+    def _set_hyperparameters(self, base_model: str, n_epochs: int, custom_suffix: str):
+        self.base_model = base_model
+        self.n_epochs = n_epochs
+        self.custom_suffix = custom_suffix
+        if base_model not in self._supported_tunable_models:
+            raise ValueError(
+                f"Model {base_model} is not supported. Supported models are"
+                f" {self._supported_tunable_models}"
+            )
+
     def _tune(self, X, y):
         if self.base_model.startswith(("azure::", "gpt4all")):
             raise ValueError(
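The _prefer_json_output flag is a class-level switch: GPTMixin defaults it to False and GPTClassifierMixin flips it to True, so classifier completions are requested in JSON mode while plain text-completion mixins are not. A self-contained sketch of that flag-propagation pattern in isolation; the class names below are illustrative stand-ins, not the library's:

class CompletionBase:
    _prefer_json_output = False  # default: plain text completions

    def request_kwargs(self, model: str) -> dict:
        # the flag is read at request time, mirroring _get_chat_completion in the diff
        return {"model": model, "json_response": self._prefer_json_output}


class ClassifierLike(CompletionBase):
    _prefer_json_output = True  # classifiers ask for JSON so the label can be parsed reliably


print(CompletionBase().request_kwargs("gpt-3.5-turbo-1106"))
# {'model': 'gpt-3.5-turbo-1106', 'json_response': False}
print(ClassifierLike().request_kwargs("gpt-3.5-turbo-1106"))
# {'model': 'gpt-3.5-turbo-1106', 'json_response': True}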

skllm/llm/vertex/mixin.py

Lines changed: 22 additions & 3 deletions
@@ -6,11 +6,12 @@
     BaseTextCompletionMixin,
     BaseTunableMixin,
 )
+from skllm.llm.vertex.tuning import tune
 from skllm.llm.vertex.completion import get_completion_chat_mode, get_completion
 from skllm.utils import extract_json_key
 import numpy as np
 from tqdm import tqdm
-import json
+import pandas as pd
 
 
 class VertexMixin:
@@ -68,6 +69,24 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:
 
 
 class VertexTunableMixin(BaseTunableMixin):
-    # TODO
+    _supported_tunable_models = ["text-bison@002"]
+
+    def _set_hyperparameters(self, base_model: str, n_update_steps: int, **kwargs):
+        self.verify_model_is_supported(base_model)
+        self.base_model = base_model
+        self.n_update_steps = n_update_steps
+
+    def verify_model_is_supported(self, model: str):
+        if model not in self._supported_tunable_models:
+            raise ValueError(
+                f"Model {model} is not supported. Supported models are"
+                f" {self._supported_tunable_models}"
+            )
+
     def _tune(self, X: Any, y: Any):
-        raise NotImplementedError("Tuning is not yet supported for Vertex AI.")
+        df = pd.DataFrame({"input_text": X, "output_text": y})
+        job = tune(self.base_model, df, self.n_update_steps)._job
+        tuned_model = job.result()
+        self.tuned_model_ = tuned_model._model_resource_name
+        self.model = tuned_model
+        return self
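The new _tune converts X and y into the two-column DataFrame that the Vertex tuning helper expects and then blocks on the tuning job. A hedged sketch of just the data-preparation step with toy strings (the actual tune call would start a real Vertex AI tuning job, so it is omitted here):

import pandas as pd

X = ["What is the capital of France?", "What is 2 + 2?"]
y = ["Paris", "4"]

# same structure that _tune builds before handing it to skllm.llm.vertex.tuning.tune
df = pd.DataFrame({"input_text": X, "output_text": y})
print(df)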

skllm/models/_base/text2text.py

Lines changed: 79 additions & 1 deletion
@@ -1 +1,79 @@
-pass
+from typing import Any, Union, List, Optional
+from abc import abstractmethod, ABC
+from numpy import ndarray
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+from skllm.utils import to_numpy as _to_numpy
+from sklearn.base import (
+    BaseEstimator as _SklBaseEstimator,
+    TransformerMixin as _SklTransformerMixin,
+)
+from skllm.llm.base import BaseTunableMixin as _BaseTunableMixin
+
+
+class BaseText2TextModel(ABC, _SklBaseEstimator, _SklTransformerMixin):
+    def fit(self, X: Any, y: Any):
+        return self
+
+    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
+        return self.transform(X)
+
+    def fit_transform(
+        self,
+        X: Union[np.ndarray, pd.Series, List[str]],
+        y: Union[np.ndarray, pd.Series, List[str]],
+    ) -> ndarray:
+        return self.fit(X, y).transform(X)
+
+    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
+        """Predicts the class of each input.
+
+        Parameters
+        ----------
+        X : Union[np.ndarray, pd.Series, List[str]]
+            The input data to predict the class of.
+
+        Returns
+        -------
+        List[str]
+        """
+        X = _to_numpy(X)
+        predictions = []
+        for i in tqdm(range(len(X))):
+            predictions.append(self._predict_single(X[i]))
+        return predictions
+
+    def _predict_single(self, x: Any) -> Any:
+        prompt_dict = self._get_prompt(x)
+        # this will be inherited from the LLM
+        prediction = self._get_chat_completion(model=self.model, **prompt_dict)
+        return prediction
+
+    @abstractmethod
+    def _get_prompt(self, x: str) -> dict:
+        """Returns the prompt to use for a single input."""
+        pass
+
+
+class BaseTunableText2TextModel(BaseText2TextModel):
+    def fit(
+        self,
+        X: Union[np.ndarray, pd.Series, List[str]],
+        y: Union[np.ndarray, pd.Series, List[str]],
+    ):
+        if not isinstance(self, _BaseTunableMixin):
+            raise TypeError(
+                "Classifier must be mixed with a skllm.llm.base.BaseTunableMixin class"
+            )
+        self._tune(X, y)
+        return self
+
+    def _get_prompt(self, x: str) -> dict:
+        """Returns the prompt to use for a single input."""
+        return str(x)
+
+    def _predict_single(self, x: str) -> str:
+        if self.model is None:
+            raise RuntimeError("Model has not been tuned yet")
+        return super()._predict_single(x)
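BaseText2TextModel leaves two hooks to concrete subclasses: _get_prompt (abstract here) and _get_chat_completion (normally contributed by an LLM mixin). A hedged sketch that stubs both hooks to show how transform iterates over the inputs; EchoText2Text is illustrative only, has no LLM behind it, and assumes scikit-llm at the state of this commit is importable:

from skllm.models._base.text2text import BaseText2TextModel


class EchoText2Text(BaseText2TextModel):
    """Illustrative subclass: no LLM call, just echoes an upper-cased prompt."""

    model = "echo"  # _predict_single passes this to _get_chat_completion

    def _get_prompt(self, x: str) -> dict:
        # the returned dict is unpacked into _get_chat_completion as keyword arguments
        return {"messages": f"Rewrite politely: {x}"}

    def _get_chat_completion(self, model: str, messages: str):
        # normally provided by an LLM mixin; stubbed out here
        return messages.upper()


print(EchoText2Text().transform(["hello", "world"]))
# ['REWRITE POLITELY: HELLO', 'REWRITE POLITELY: WORLD']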

skllm/models/gpt/classification/tunable.py

Lines changed: 4 additions & 12 deletions
@@ -10,19 +10,11 @@
 from typing import Optional
 
 
-class _Tunable(_BaseTunableClassifier, _GPTClassifierMixin, _GPTTunableMixin):
-    def _set_hyperparameters(self, base_model: str, n_epochs: int, custom_suffix: str):
-        self.base_model = base_model
-        self.n_epochs = n_epochs
-        self.custom_suffix = custom_suffix
-        if base_model not in self._supported_tunable_models:
-            raise ValueError(
-                f"Model {base_model} is not supported. Supported models are"
-                f" {self._supported_tunable_models}"
-            )
+class _TunableClassifier(_BaseTunableClassifier, _GPTClassifierMixin, _GPTTunableMixin):
+    pass
 
 
-class GPTClassifier(_Tunable, _SingleLabelMixin):
+class GPTClassifier(_TunableClassifier, _SingleLabelMixin):
     def __init__(
         self,
         base_model: str = "gpt-3.5-turbo-0613",
@@ -40,7 +32,7 @@ def __init__(
         self._set_hyperparameters(base_model, n_epochs, custom_suffix)
 
 
-class MultiLabelGPTClassifier(_Tunable, _MultiLabelMixin):
+class MultiLabelGPTClassifier(_TunableClassifier, _MultiLabelMixin):
     def __init__(
         self,
         base_model: str = "gpt-3.5-turbo-0613",
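With the hyperparameter logic moved into GPTTunableMixin, the tunable classifiers become thin compositions of mixins. A hedged usage sketch with toy data; it is not executed here because fit would launch a real OpenAI fine-tuning job, it assumes OpenAI credentials are configured beforehand, and the __init__ arguments other than base_model (e.g. n_epochs, custom_suffix) are inferred from the truncated diff rather than shown in full:

from skllm.models.gpt.classification.tunable import GPTClassifier

X = ["I loved this movie", "Worst purchase I ever made"]
y = ["positive", "negative"]

clf = GPTClassifier(base_model="gpt-3.5-turbo-0613")  # n_epochs / custom_suffix also configurable per the diff
clf.fit(X, y)          # triggers _tune, i.e. an OpenAI fine-tuning job
print(clf.predict(X))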

skllm/models/gpt/text2text/__init__.py

Whitespace-only changes.
