Skip to content

Commit a42db43

Browse files
authored
Merge pull request #58 from iryna-kondr/feature-gpt-tuning
gpt_tuning
2 parents 73a036a + c637b3f commit a42db43

File tree

4 files changed

+311
-1
lines changed

4 files changed

+311
-1
lines changed

README.md

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,15 @@ Note: as the model is not being re-trained, but uses the training data during in
202202

203203
### Dynamic Few-Shot Text Classification
204204

205+
_To use this feature, you need to install the `annoy` library:_
206+
207+
```bash
208+
pip install scikit-llm[annoy]
209+
```
210+
205211
`DynamicFewShotGPTClassifier` dynamically selects N samples per class to include in the prompt. This allows the few-shot classifier to scale to datasets that are too large for the standard context window of LLMs.
206212

207-
*How does it work?*
213+
_How does it work?_
208214

209215
During fitting, the whole dataset is partitioned by class, vectorized, and stored.
210216

@@ -288,6 +294,67 @@ clf.fit(X_train, y_train_encoded)
288294
yh = clf.predict(X_test)
289295
```
290296

297+
### LLM Fine-Tuning
298+
299+
At the moment the following scenarios are supported for tuning:
300+
301+
- **Text classification**: the model is fine-tuned to predict a single label per sample. The following estimators are supported:
302+
- `skllm.models.palm.PaLMClassifier`
303+
- `skllm.models.gpt.GPTClassifier`
304+
- **Text to text**: the model is fine-tuned on arbitrary text input-output pairs. The following estimators are supported:
305+
- `skllm.models.palm.PaLM`
306+
- `skllm.models.gpt.GPT`
307+
308+
Example 1: Fine-tuning a PaLM model for text classification
309+
310+
```python
311+
from skllm.models.palm import PaLMClassifier
312+
clf = PaLMClassifier(n_update_steps=100)
313+
clf.fit(X_train, y_train) # y_train is a list of labels
314+
labels = clf.predict(X_test)
315+
```
316+
317+
Example 2: Fine-tuning a PaLM model for text to text tasks
318+
319+
```python
320+
from skllm.models.palm import PaLM
321+
clf = PaLM(n_update_steps=100)
322+
clf.fit(X_train, y_train) # y_train is any desired output text
323+
labels = clf.predict(X_test)
324+
```
325+
326+
_Note:_ Tuning PaLM models requires a Vertex AI account. Please refer to our [official guide on Medium](https://medium.com/@iryna230520/fine-tune-google-palm-2-with-scikit-llm-d41b0aa673a5) for more details.
327+
328+
Example 3: Fine-tuning a GPT model for text classification
329+
330+
```python
331+
from skllm.models.gpt import GPTClassifier
332+
333+
clf = GPTClassifier(
334+
base_model = "gpt-3.5-turbo-0613",
335+
n_epochs = None, # int or None. When None, will be determined automatically by OpenAI
336+
default_label = "Random", # optional
337+
)
338+
339+
clf.fit(X_train, y_train) # y_train is a list of labels
340+
labels = clf.predict(X_test)
341+
```
342+
343+
Example 4: Fine-tuning a GPT model for text to text tasks
344+
345+
```python
346+
from skllm.models.gpt import GPT
347+
348+
clf = GPT(
349+
base_model = "gpt-3.5-turbo-0613",
350+
n_epochs = None, # int or None. When None, will be determined automatically by OpenAI
351+
system_msg = "You are a text processing model."
352+
)
353+
354+
clf.fit(X_train, y_train) # y_train is any desired output text
355+
labels = clf.predict(X_test)
356+
```
357+
291358
### Text Summarization
292359

293360
GPT excels at performing summarization tasks. Therefore, we provide `GPTSummarizer`, which can be used either as a stand-alone estimator or as a preprocessor (in this case we can make an analogy with a dimensionality reduction preprocessor).

skllm/models/gpt/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
ZeroShotGPTClassifier,
55
MultiLabelZeroShotGPTClassifier,
66
)
7+
8+
from skllm.models.gpt.gpt import GPTClassifier, GPT

skllm/models/gpt/gpt.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from typing import Optional, Union, List
2+
import pandas as pd
3+
from skllm.models._base import _BaseZeroShotGPTClassifier
4+
from skllm.prompts.builders import build_zero_shot_prompt_slc
5+
from skllm.openai.credentials import set_credentials
6+
from skllm.openai.tuning import create_tuning_job, await_results, delete_file
7+
import numpy as np
8+
import json
9+
import uuid
10+
11+
12+
def _build_clf_example(
    x: str, y: str, system_msg="You are a text classification model."
):
    """Serialize one chat-format training sample to a JSON string.

    Parameters
    ----------
    x : str
        content of the user message
    y : str
        content of the assistant message
    system_msg : str, optional
        content of the system message

    Returns
    -------
    str
        JSON object with a single "messages" key that holds the system,
        user and assistant messages, in that order
    """
    roles = ("system", "user", "assistant")
    contents = (system_msg, x, y)
    messages = [
        {"role": role, "content": content}
        for role, content in zip(roles, contents)
    ]
    return json.dumps({"messages": messages})
23+
24+
25+
class _Tunable:
    """Mixin that fine-tunes an OpenAI chat model on (X, y) pairs.

    The host class must provide ``_get_prompt``, ``_get_openai_key`` and
    ``_get_openai_org`` plus the ``base_model``, ``n_epochs`` and
    ``custom_suffix`` attributes; none of them are defined here.
    """

    # system message written at the start of every training sample
    system_msg = "You are a text classification model."

    def _build_label(self, label: str):
        """Return the label wrapped in a JSON object: ``{"label": <label>}``."""
        return json.dumps({"label": label})

    def _tune(self, X, y):
        """Write the training file, run a tuning job and store its model id.

        Each (x, y) pair becomes one JSON line in a ``skllm_<uuid>.jsonl``
        file (relative path). The file is handed to ``create_tuning_job``;
        once the job succeeds, ``self.openai_model`` is set to the job's
        ``fine_tuned_model`` and the job's training file is deleted through
        ``delete_file``.

        Parameters
        ----------
        X : iterable of str
            inputs; each one is passed through ``self._get_prompt``
        y : iterable
            targets; each one is passed through ``self._build_label``
        """
        # local imports keep this block independent of the module's import list
        import contextlib
        import os

        filename = f"skllm_{uuid.uuid4()}.jsonl"
        try:
            with open(filename, "w+") as f:
                for xi, yi in zip(X, y):
                    f.write(
                        _build_clf_example(
                            self._get_prompt(xi),
                            self._build_label(yi),
                            self.system_msg,
                        )
                    )
                    f.write("\n")
            set_credentials(self._get_openai_key(), self._get_openai_org())
            job = create_tuning_job(
                self.base_model,
                filename,
                self.n_epochs,
                self.custom_suffix,
            )
        finally:
            # Best-effort cleanup: if anything above fails, the local training
            # file must not be left behind. On success it may already have
            # been removed, hence the suppression.
            with contextlib.suppress(OSError):
                os.remove(filename)
        print(f"Created new tuning job. JOB_ID = {job['id']}")
        job = await_results(job["id"])
        self.openai_model = job["fine_tuned_model"]
        delete_file(job["training_file"])
        print(f"Finished training. Number of trained tokens: {job['trained_tokens']}.")
54+
55+
56+
class GPTClassifier(_BaseZeroShotGPTClassifier, _Tunable):
    """Single-label text classifier backed by a fine-tuned GPT model."""

    supported_models = ["gpt-3.5-turbo-0613"]

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        default_label: Optional[str] = "Random",
        openai_key: Optional[str] = None,
        openai_org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
    ):
        """Store the tuning settings and initialise the base classifier.

        Parameters
        ----------
        base_model : str
            model to fine-tune; must be listed in ``supported_models``
        default_label : Optional[str]
            forwarded to the base class
        openai_key : Optional[str]
            forwarded to the base class
        openai_org : Optional[str]
            forwarded to the base class
        n_epochs : Optional[int]
            number of epochs handed to the tuning job
        custom_suffix : Optional[str]
            suffix handed to the tuning job

        Raises
        ------
        ValueError
            if ``base_model`` is not in ``supported_models``
        """
        self.base_model = base_model
        self.n_epochs = n_epochs
        self.custom_suffix = custom_suffix
        if base_model not in self.supported_models:
            message = (
                f"Model {base_model} is not supported. Supported models are"
                f" {self.supported_models}"
            )
            raise ValueError(message)
        # "undefined" is only a placeholder; the tuning step assigns the
        # real model id to ``openai_model``.
        super().__init__(
            openai_model="undefined",
            default_label=default_label,
            openai_key=openai_key,
            openai_org=openai_org,
        )

    def _get_prompt(self, x: str) -> str:
        """Build the single-label zero-shot prompt for ``x`` and the known classes."""
        labels_repr = repr(self.classes_)
        return build_zero_shot_prompt_slc(x, labels_repr)

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Union[np.ndarray, pd.Series, List[str]],
    ):
        """Fit the base classifier, then fine-tune the model on the data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training data
        y : Union[np.ndarray, pd.Series, List[str]]
            training labels

        Returns
        -------
        GPTClassifier
            self
        """
        X, y = self._to_np(X), self._to_np(y)
        super().fit(X, y)
        self._tune(X, y)
        return self
112+
113+
114+
# similarly to PaLM, this is not a classifier, but a quick way to re-use the code
# the hierarchy of classes will be reworked in the next releases
class GPT(_BaseZeroShotGPTClassifier, _Tunable):
    """GPT model fine-tuned on arbitrary text input-output pairs."""

    supported_models = ["gpt-3.5-turbo-0613"]

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        openai_key: Optional[str] = None,
        openai_org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
        system_msg: Optional[str] = "You are a text processing model.",
    ):
        """Store the tuning settings and initialise the base class.

        Parameters
        ----------
        base_model : str
            model to fine-tune; must be listed in ``supported_models``
        openai_key : Optional[str]
            forwarded to the base class
        openai_org : Optional[str]
            forwarded to the base class
        n_epochs : Optional[int]
            number of epochs handed to the tuning job
        custom_suffix : Optional[str]
            suffix handed to the tuning job
        system_msg : Optional[str]
            system message stored on the instance

        Raises
        ------
        ValueError
            if ``base_model`` is not in ``supported_models``
        """
        self.base_model = base_model
        self.n_epochs = n_epochs
        self.custom_suffix = custom_suffix
        self.system_msg = system_msg
        if base_model not in self.supported_models:
            message = (
                f"Model {base_model} is not supported. Supported models are"
                f" {self.supported_models}"
            )
            raise ValueError(message)
        base_kwargs = {
            "openai_model": "undefined",  # this will be rewritten later
            "default_label": "Random",  # just for compatibility
            "openai_key": openai_key,
            "openai_org": openai_org,
        }
        super().__init__(**base_kwargs)

    def _get_prompt(self, x: str) -> str:
        """Return the input text unchanged; no prompt template is applied."""
        return x

    def _build_label(self, label: str):
        """Return the target text unchanged (no JSON wrapping)."""
        return label

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Union[np.ndarray, pd.Series, List[str]],
    ):
        """Fine-tune the model on the given input-output pairs.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training inputs
        y : Union[np.ndarray, pd.Series, List[str]]
            target outputs

        Returns
        -------
        GPT
            self
        """
        X, y = self._to_np(X), self._to_np(y)
        self._tune(X, y)
        return self

skllm/openai/tuning.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Optional
2+
import openai
3+
from time import sleep
4+
from datetime import datetime
5+
import os
6+
7+
8+
def create_tuning_job(
    model: str,
    training_file: str,
    n_epochs: Optional[int] = None,
    suffix: Optional[str] = None,
):
    """Upload a training file and start a fine-tuning job for it.

    Parameters
    ----------
    model : str
        name of the model to fine-tune
    training_file : str
        path of the local training file; it is removed from disk once the
        uploaded copy has been processed
    n_epochs : Optional[int]
        number of epochs; sent as a hyperparameter only when not None
    suffix : Optional[str]
        suffix for the job; sent only when not None

    Returns
    -------
    the object returned by ``openai.FineTuningJob.create``

    Raises
    ------
    RuntimeError
        propagated from ``wait_file_ready`` when the uploaded file is
        missing or ends up in a failed state
    """
    # The context manager closes the handle as soon as the upload call
    # returns. A handle that stays open leaks a descriptor and can prevent
    # the os.remove below from succeeding on Windows.
    with open(training_file, "rb") as f:
        out = openai.File.create(file=f, purpose="fine-tune")
    print(f"Created new file. FILE_ID = {out['id']}")
    print("Waiting for file to be processed ...")
    while not wait_file_ready(out["id"]):
        sleep(5)
    # delete the training_file after it is uploaded
    os.remove(training_file)
    params = {
        "model": model,
        "training_file": out["id"],
    }
    if n_epochs is not None:
        params["hyperparameters"] = {"n_epochs": n_epochs}
    if suffix is not None:
        params["suffix"] = suffix
    return openai.FineTuningJob.create(**params)
30+
31+
32+
def await_results(job_id: str, check_interval: int = 120):
    """Poll a fine-tuning job until it reaches a final status.

    Parameters
    ----------
    job_id : str
        id of the fine-tuning job to retrieve
    check_interval : int
        seconds to sleep between two status checks

    Returns
    -------
    the retrieved job object once its status is "succeeded"

    Raises
    ------
    RuntimeError
        if the job status is "failed" or "cancelled"
    """
    failure_states = ("failed", "cancelled")
    while True:
        job = openai.FineTuningJob.retrieve(job_id)
        status = job["status"]
        if status == "succeeded":
            return job
        if status in failure_states:
            print(job)
            raise RuntimeError(f"Tuning job failed with status {status}")
        # any other status: report it and check again later
        timestamp = datetime.now()
        print(
            f"[{timestamp}] Waiting for tuning job to complete. Current status: {status}"
        )
        sleep(check_interval)
47+
48+
def delete_file(file_id: str):
    """Delete the file with the given id via ``openai.File.delete``."""
    openai.File.delete(file_id)
50+
51+
def wait_file_ready(file_id):
    """Check once whether an uploaded file has been processed.

    Parameters
    ----------
    file_id : str
        id of the file to look up in ``openai.File.list()``

    Returns
    -------
    bool
        True when the file status is "processed", False for any other
        status that is not a failure

    Raises
    ------
    RuntimeError
        if the file is not in the listing, or its status is "error",
        "deleting" or "deleted"
    """
    listing = openai.File.list()["data"]
    # only the first entry with a matching id is considered
    entry = next((item for item in listing if item["id"] == file_id), None)
    if entry is None:
        raise RuntimeError(f"File {file_id} not found")
    if entry["status"] == "processed":
        return True
    if entry["status"] in ["error", "deleting", "deleted"]:
        print(entry)
        raise RuntimeError(
            f"File upload {file_id} failed with status {entry['status']}"
        )
    return False

0 commit comments

Comments
 (0)