
Commit 0837b4a

GPT tuning

OKUA1 and iryna-kondr committed
Co-authored-by: Iryna Kondrashchenko <[email protected]>
1 parent 8cc24b3 commit 0837b4a

File tree: 4 files changed (+107 −3 lines)


skllm/llm/gpt/mixin.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -197,7 +197,7 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:
 
 # for now this works only with OpenAI
 class GPTTunableMixin(BaseTunableMixin):
-    system_msg = "You are a text classification model."
+    _supported_tunable_models = ["gpt-3.5-turbo-0613", "gpt-3.5-turbo"]
 
     def _build_label(self, label: str):
         return json.dumps({"label": label})
@@ -211,9 +211,16 @@ def _tune(self, X, y):
         filename = f"skllm_{file_uuid}.jsonl"
         with open(filename, "w+") as f:
             for xi, yi in zip(X, y):
+                prompt = self._get_prompt(xi)
+                if not isinstance(prompt["messages"], str):
+                    raise ValueError(
+                        "Incompatible prompt. Use a prompt with a single message."
+                    )
                 f.write(
                     _build_clf_example(
-                        self._get_prompt(xi), self._build_label(yi), self.system_msg
+                        prompt["messages"],
+                        self._build_label(yi),
+                        prompt["system_message"],
                     )
                 )
                 f.write("\n")
```

skllm/models/_base/classifier.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -19,6 +19,8 @@
     ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
     FEW_SHOT_CLF_PROMPT_TEMPLATE,
     FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
+    ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE,
+    ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE,
 )
 from skllm.prompts.builders import (
     build_zero_shot_prompt_slc,
@@ -134,7 +136,7 @@ def _extract_labels(self, y) -> List[str]:
 class BaseClassifier(ABC, _SklBaseEstimator, _SklClassifierMixin):
     def __init__(
         self,
-        model: str,
+        model: Optional[str],  # model can initially be None for tunable estimators
         default_label: str = "Random",
         max_labels: Optional[int] = 5,
         prompt_template: Optional[str] = None,
@@ -452,3 +454,26 @@ def fit(
         super().fit(X, y)
         self._tune(X, y)
         return self
+
+    def _get_prompt_template(self) -> str:
+        """Returns the prompt template to use for a single input."""
+        if self.prompt_template is not None:
+            return self.prompt_template
+        elif isinstance(self, SingleLabelMixin):
+            return ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE
+        return ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE
+
+    def _get_prompt(self, x: str) -> dict:
+        """Returns the prompt to use for a single input."""
+        if isinstance(self, SingleLabelMixin):
+            prompt = build_zero_shot_prompt_slc(
+                x, repr(self.classes_), template=self._get_prompt_template()
+            )
+        else:
+            prompt = build_zero_shot_prompt_mlc(
+                x,
+                repr(self.classes_),
+                self.max_labels,
+                template=self._get_prompt_template(),
+            )
+        return {"messages": prompt, "system_message": "You are a text classifier."}
```
Lines changed: 62 additions & 0 deletions
```diff
@@ -0,0 +1,62 @@
+from skllm.llm.gpt.mixin import (
+    GPTClassifierMixin as _GPTClassifierMixin,
+    GPTTunableMixin as _GPTTunableMixin,
+)
+from skllm.models._base.classifier import (
+    BaseTunableClassifier as _BaseTunableClassifier,
+    SingleLabelMixin as _SingleLabelMixin,
+    MultiLabelMixin as _MultiLabelMixin,
+)
+from typing import Optional
+
+
+class _Tunable(_BaseTunableClassifier, _GPTClassifierMixin, _GPTTunableMixin):
+    def _set_hyperparameters(self, base_model: str, n_epochs: int, custom_suffix: str):
+        self.base_model = base_model
+        self.n_epochs = n_epochs
+        self.custom_suffix = custom_suffix
+        if base_model not in self._supported_tunable_models:
+            raise ValueError(
+                f"Model {base_model} is not supported. Supported models are"
+                f" {self._supported_tunable_models}"
+            )
+
+
+class GPTClassifier(_Tunable, _SingleLabelMixin):
+    def __init__(
+        self,
+        base_model: str = "gpt-3.5-turbo-0613",
+        default_label: Optional[str] = "Random",
+        key: Optional[str] = None,
+        org: Optional[str] = None,
+        n_epochs: Optional[int] = None,
+        custom_suffix: Optional[str] = "skllm",
+        prompt_template: Optional[str] = None,
+    ):
+        super().__init__(
+            model=None, default_label=default_label, prompt_template=prompt_template
+        )
+        self._set_keys(key, org)
+        self._set_hyperparameters(base_model, n_epochs, custom_suffix)
+
+
+class MultiLabelGPTClassifier(_Tunable, _MultiLabelMixin):
+    def __init__(
+        self,
+        base_model: str = "gpt-3.5-turbo-0613",
+        default_label: Optional[str] = "Random",
+        key: Optional[str] = None,
+        org: Optional[str] = None,
+        n_epochs: Optional[int] = None,
+        custom_suffix: Optional[str] = "skllm",
+        prompt_template: Optional[str] = None,
+        max_labels: Optional[int] = 5,
+    ):
+        super().__init__(
+            model=None,
+            default_label=default_label,
+            prompt_template=prompt_template,
+            max_labels=max_labels,
+        )
+        self._set_keys(key, org)
+        self._set_hyperparameters(base_model, n_epochs, custom_suffix)
```
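A short usage sketch of the new tunable estimators. The toy data and key setup are illustrative, and the import path for `GPTClassifier` depends on where this new module lives (its filename is not shown in this view):

```python
from skllm.config import SKLLMConfig
# from skllm.models... import GPTClassifier  # path of the new module not shown here

SKLLMConfig.set_openai_key("<YOUR_KEY>")  # or pass key=... to the estimator

# Hypothetical toy data; any parallel sequences of texts and labels work.
X = ["I loved this movie", "Terrible service", "Would buy again"]
y = ["positive", "negative", "positive"]

clf = GPTClassifier(base_model="gpt-3.5-turbo-0613", n_epochs=1)
clf.fit(X, y)          # writes the JSONL file and launches a fine-tuning job
print(clf.predict(X))  # uses the tuned model once the job completes
```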

skllm/prompts/templates.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -15,6 +15,16 @@
 Your JSON response:
 """
 
+ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE = """
+Classify the following text into one of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`.
+Text: ```{x}```
+"""
+
+ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE = """
+Classify the following text into at least 1 but up to {max_cats} of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`.
+Text: ```{x}```
+"""
+
 FEW_SHOT_CLF_PROMPT_TEMPLATE = """
 You will be provided with the following information:
 1. An arbitrary text sample. The sample is delimited with triple backticks.
```
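To see how the new short template renders, a quick sketch with placeholder values (the label set and text are illustrative):

```python
# Assumes the template above is importable from skllm.prompts.templates.
from skllm.prompts.templates import ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE

filled = ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE.format(
    labels="['spam', 'ham']",  # repr(self.classes_) in _get_prompt
    x="Win a free cruise now!",
)
print(filled)
# Classify the following text into one of the following classes: ['spam', 'ham']. ...
# Text: ```Win a free cruise now!```
```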
