Skip to content

Commit 1c93c89

Browse files
authored
Merge pull request #63 from ashwinprasadme/dev
Added MultiLabelFewShotGPTClassifier and MultiLabel GPT Fine-tuning Support
2 parents 7dadca5 + 7a608fe commit 1c93c89

File tree

6 files changed

+335
-11
lines changed

6 files changed

+335
-11
lines changed

README.md

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,21 @@ While the api remains the same as for the zero shot classifier, there are a few
200200

201201
Note: as the model is not being re-trained, but uses the training data during inference, one could say that this is still a (different) zero-shot approach.
202202

203+
### Multi-Label Few-Shot Text Classification
204+
205+
Example:
206+
207+
```python
208+
from skllm.models.gpt.gpt_few_shot_clf import MultiLabelFewShotGPTClassifier
209+
from skllm.datasets import get_multilabel_classification_dataset
210+
211+
X, y = get_multilabel_classification_dataset()
212+
213+
clf = MultiLabelFewShotGPTClassifier(max_labels=2, openai_model="gpt-3.5-turbo")
214+
clf.fit(X, y)
215+
labels = clf.predict(X)
216+
```
217+
203218
### Dynamic Few-Shot Text Classification
204219

205220
_To use this feature, you need to install `annoy` library:_
@@ -340,7 +355,23 @@ clf.fit(X_train, y_train) # y_train is a list of labels
340355
labels = clf.predict(X_test)
341356
```
342357

343-
Example 4: Fine-tuning a GPT model for text to text tasks
358+
Example 4: Fine-tuning a GPT model for multi-label text classification
359+
360+
```python
361+
from skllm.models.gpt import MultiLabelGPTClassifier
362+
363+
clf = MultiLabelGPTClassifier(
364+
base_model = "gpt-3.5-turbo-0613",
365+
n_epochs = None, # int or None. When None, will be determined automatically by OpenAI
366+
default_label = "Random", # optional
367+
max_labels = 2,
368+
)
369+
370+
clf.fit(X_train, y_train)
371+
labels = clf.predict(X_test)
372+
```
373+
374+
Example 5: Fine-tuning a GPT model for text to text tasks
344375

345376
```python
346377
from skllm.models.gpt import GPT

skllm/models/gpt/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from skllm.models.gpt.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
2-
from skllm.models.gpt.gpt_few_shot_clf import FewShotGPTClassifier
2+
from skllm.models.gpt.gpt_few_shot_clf import (
3+
FewShotGPTClassifier,
4+
MultiLabelFewShotGPTClassifier,
5+
)
36
from skllm.models.gpt.gpt_zero_shot_clf import (
47
ZeroShotGPTClassifier,
58
MultiLabelZeroShotGPTClassifier,
69
)
710

8-
from skllm.models.gpt.gpt import GPTClassifier, GPT
11+
from skllm.models.gpt.gpt import GPTClassifier, GPT, MultiLabelGPTClassifier

skllm/models/gpt/gpt.py

Lines changed: 140 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1-
from typing import Optional, Union, List
1+
import json
2+
import uuid
3+
from typing import List, Optional, Union
4+
5+
import numpy as np
26
import pandas as pd
7+
38
from skllm.models._base import _BaseZeroShotGPTClassifier
4-
from skllm.prompts.builders import build_zero_shot_prompt_slc
59
from skllm.openai.credentials import set_credentials
6-
from skllm.openai.tuning import create_tuning_job, await_results, delete_file
7-
import numpy as np
8-
import json
9-
import uuid
10+
from skllm.openai.tuning import await_results, create_tuning_job, delete_file
11+
from skllm.prompts.builders import (
12+
build_zero_shot_prompt_mlc,
13+
build_zero_shot_prompt_slc,
14+
)
15+
16+
from skllm.utils import extract_json_key
17+
18+
_TRAINING_SAMPLE_PROMPT_TEMPLATE = """
19+
Sample input:
20+
```{x}```
21+
22+
Sample target: {label}
23+
"""
1024

1125

1226
def _build_clf_example(
@@ -111,6 +125,126 @@ def fit(
111125
return self
112126

113127

128+
class MultiLabelGPTClassifier(_BaseZeroShotGPTClassifier, _Tunable):
    """Fine-tunable GPT classifier for multi-label classification.

    Parameters
    ----------
    base_model : str
        model to fine-tune, must be listed in ``supported_models``,
        by default "gpt-3.5-turbo-0613"
    default_label : Optional[Union[List[str], str]]
        fallback label(s) forwarded to the base class; a string value must be
        exactly "Random", by default "Random"
    openai_key : Optional[str]
        forwarded to the base class constructor, by default None
    openai_org : Optional[str]
        forwarded to the base class constructor, by default None
    n_epochs : Optional[int]
        number of fine-tuning epochs, stored as an attribute, by default None
    custom_suffix : Optional[str]
        stored as an attribute, by default "skllm"
        (presumably consumed by the tuning logic of ``_Tunable`` - confirm)
    max_labels : int
        maximum number of labels assigned to a sample, must be at least 2,
        by default 3

    Raises
    ------
    ValueError
        if ``max_labels`` is smaller than 2, if ``default_label`` is a string
        other than "Random", or if ``base_model`` is not supported
    """

    supported_models = ["gpt-3.5-turbo-0613"]

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        default_label: Optional[Union[List[str], str]] = "Random",
        openai_key: Optional[str] = None,
        openai_org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
        max_labels: int = 3,
    ):
        self.base_model = base_model
        self.n_epochs = n_epochs
        self.custom_suffix = custom_suffix
        if max_labels < 2:
            raise ValueError("max_labels should be at least 2")
        if isinstance(default_label, str) and default_label != "Random":
            raise ValueError("default_label should be a list of strings or 'Random'")
        self.max_labels = max_labels

        if base_model not in self.supported_models:
            raise ValueError(
                f"Model {base_model} is not supported. Supported models are"
                f" {self.supported_models}"
            )
        # The base class is initialised with the placeholder model name
        # "undefined"; base_model itself is only kept in self.base_model here.
        super().__init__(
            openai_model="undefined",
            default_label=default_label,
            openai_key=openai_key,
            openai_org=openai_org,
        )

    def _get_prompt(self, x: str) -> str:
        """Generates the prompt for the given input.

        Parameters
        ----------
        x : str
            sample

        Returns
        -------
        str
            final prompt
        """
        return build_zero_shot_prompt_mlc(
            x=x,
            labels=repr(self.classes_),
            max_cats=self.max_labels,
        )

    def _extract_labels(self, y) -> List[str]:
        """Flattens per-sample label collections into a single list.

        Parameters
        ----------
        y : Any
            iterable whose items are themselves iterables of labels

        Returns
        -------
        List[str]
            all labels in encounter order, duplicates included
        """
        return [label for sample_labels in y for label in sample_labels]

    def _predict_single(self, x):
        """Predicts the labels for a single sample.

        Parameters
        ----------
        x : str
            sample

        Returns
        -------
        List[str]
            predicted labels that are present in ``classes_``, at most
            ``max_labels`` of them; when none is found, the value returned by
            ``self._get_default_label()`` (defined outside this class)
        """
        completion = self._get_chat_completion(x)
        try:
            labels = extract_json_key(
                completion["choices"][0]["message"]["content"], "label"
            )
            # A non-list value is treated as a comma-separated string.
            if not isinstance(labels, list):
                labels = labels.split(",")
            labels = [label.strip() for label in labels]
        except Exception as e:
            # Deliberately best effort: any failure while reading the
            # completion is reported and leads to the default label below.
            print(completion)
            print(f"Could not extract the label from the completion: {str(e)}")
            labels = []

        # Keep only labels that are known classes.
        labels = [label for label in labels if label in self.classes_]
        if not labels:
            labels = self._get_default_label()
        if labels is not None and len(labels) > self.max_labels:
            # Keep exactly max_labels items; slicing to max_labels - 1 would
            # return one label fewer than the configured maximum.
            labels = labels[: self.max_labels]
        return labels

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: List[List[str]],
    ):
        """Fits the model to the given data.

        Converts the inputs with ``self._to_np``, runs the inherited ``fit``
        and then starts the tuning step through ``self._tune``.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training data
        y : List[List[str]]
            training labels

        Returns
        -------
        MultiLabelGPTClassifier
            self
        """
        X = self._to_np(X)
        y = self._to_np(y)
        super().fit(X, y)
        self._tune(X, y)
        return self
246+
247+
114248
# similarly to PaLM, this is not a classifier, but a quick way to re-use the code
115249
# the hierarchy of classes will be reworked in the next releases
116250
class GPT(_BaseZeroShotGPTClassifier, _Tunable):

skllm/models/gpt/gpt_few_shot_clf.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import List, Union
1+
from typing import List, Literal, Optional, Union
22

33
import numpy as np
44
import pandas as pd
55

66
from skllm.models._base import _BaseZeroShotGPTClassifier
7-
from skllm.prompts.builders import build_few_shot_prompt_slc
7+
from skllm.prompts.builders import build_few_shot_prompt_mlc, build_few_shot_prompt_slc
8+
from skllm.utils import extract_json_key
89
from skllm.utils import to_numpy as _to_numpy
910

1011
_TRAINING_SAMPLE_PROMPT_TEMPLATE = """
@@ -69,3 +70,104 @@ def _get_prompt(self, x: str) -> str:
6970
return build_few_shot_prompt_slc(
7071
x=x, training_data=training_data_str, labels=repr(self.classes_)
7172
)
73+
74+
75+
class MultiLabelFewShotGPTClassifier(_BaseZeroShotGPTClassifier):
    """Few-shot multi-label classifier.

    The training data passed to ``fit`` is not used for re-training; it is
    rendered into the prompt of every prediction.

    Parameters
    ----------
    openai_key : Optional[str]
        forwarded to the base class constructor, by default None
    openai_org : Optional[str]
        forwarded to the base class constructor, by default None
    openai_model : str
        forwarded to the base class constructor, by default "gpt-3.5-turbo"
    default_label : Optional[Union[List[str], Literal["Random"]]]
        fallback label(s) forwarded to the base class; a string value must be
        exactly "Random", by default "Random"
    max_labels : int
        maximum number of labels assigned to a sample, must be at least 2,
        by default 3

    Raises
    ------
    ValueError
        if ``max_labels`` is smaller than 2 or ``default_label`` is a string
        other than "Random"
    """

    def __init__(
        self,
        openai_key: Optional[str] = None,
        openai_org: Optional[str] = None,
        openai_model: str = "gpt-3.5-turbo",
        default_label: Optional[Union[List[str], Literal["Random"]]] = "Random",
        max_labels: int = 3,
    ):
        super().__init__(openai_key, openai_org, openai_model, default_label)
        if max_labels < 2:
            raise ValueError("max_labels should be at least 2")
        if isinstance(default_label, str) and default_label != "Random":
            raise ValueError("default_label should be a list of strings or 'Random'")
        self.max_labels = max_labels

    def _extract_labels(self, y) -> List[str]:
        """Flattens per-sample label collections into a single list.

        Parameters
        ----------
        y : Any
            iterable whose items are themselves iterables of labels

        Returns
        -------
        List[str]
            all labels in encounter order, duplicates included
        """
        return [label for sample_labels in y for label in sample_labels]

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: List[List[str]],
    ):
        """Fits the model to the given data.

        Stores the training samples for later prompt construction and derives
        ``classes_`` and ``probabilities_`` from the targets.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training data
        y : List[List[str]]
            training labels, one list of labels per sample

        Returns
        -------
        MultiLabelFewShotGPTClassifier
            self

        Raises
        ------
        ValueError
            if ``X`` and ``y`` differ in length
        """
        if len(X) != len(y):
            raise ValueError("X and y must have the same length.")
        X = _to_numpy(X)
        y = _to_numpy(y)
        self.training_data_ = (X, y)
        self.classes_, self.probabilities_ = self._get_unique_targets(y)
        return self

    def _get_prompt(self, x) -> str:
        """Generates the few-shot prompt for the given input.

        Parameters
        ----------
        x : str
            sample

        Returns
        -------
        str
            final prompt, containing every stored training sample
        """
        # One rendered template per (sample, labels) pair stored by fit().
        training_data_str = "\n".join(
            _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
            for xt, yt in zip(*self.training_data_)
        )

        return build_few_shot_prompt_mlc(
            x=x,
            training_data=training_data_str,
            labels=repr(self.classes_),
            max_cats=self.max_labels,
        )

    def _predict_single(self, x):
        """Predicts the labels for a single sample.

        Parameters
        ----------
        x : str
            sample

        Returns
        -------
        List[str]
            predicted labels that are present in ``classes_``, at most
            ``max_labels`` of them; when none is found, the value returned by
            ``self._get_default_label()`` (defined outside this class)
        """
        completion = self._get_chat_completion(x)
        try:
            labels = extract_json_key(
                completion["choices"][0]["message"]["content"], "label"
            )
            # A non-list value is treated as a comma-separated string.
            if not isinstance(labels, list):
                labels = labels.split(",")
            labels = [label.strip() for label in labels]
        except Exception as e:
            # Deliberately best effort: any failure while reading the
            # completion is reported and leads to the default label below.
            print(completion)
            print(f"Could not extract the label from the completion: {str(e)}")
            labels = []

        # Keep only labels that are known classes.
        labels = [label for label in labels if label in self.classes_]
        if not labels:
            labels = self._get_default_label()
        if labels is not None and len(labels) > self.max_labels:
            # Keep exactly max_labels items; slicing to max_labels - 1 would
            # return one label fewer than the configured maximum.
            labels = labels[: self.max_labels]
        return labels

skllm/prompts/builders.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from skllm.prompts.templates import (
44
FEW_SHOT_CLF_PROMPT_TEMPLATE,
5+
FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
56
FOCUSED_SUMMARY_PROMPT_TEMPLATE,
67
SUMMARY_PROMPT_TEMPLATE,
78
TRANSLATION_PROMPT_TEMPLATE,
@@ -61,6 +62,38 @@ def build_few_shot_prompt_slc(
6162
return template.format(x=x, labels=labels, training_data=training_data)
6263

6364

65+
def build_few_shot_prompt_mlc(
    x: str,
    labels: str,
    training_data: str,
    max_cats: Union[int, str],
    template: str = FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
) -> str:
    """Builds a prompt for few-shot multi-label classification.

    Parameters
    ----------
    x : str
        sample to classify
    labels : str
        candidate labels in a list-like representation
    training_data : str
        training data to be used for few-shot learning
    max_cats : Union[int,str]
        maximum number of categories to assign
    template : str
        prompt template to use, must contain placeholders for all variables,
        by default FEW_SHOT_MLCLF_PROMPT_TEMPLATE

    Returns
    -------
    str
        prepared prompt
    """
    # Every value is handed to the template under its own parameter name.
    placeholders = {
        "x": x,
        "labels": labels,
        "training_data": training_data,
        "max_cats": max_cats,
    }
    return template.format(**placeholders)
95+
96+
6497
def build_zero_shot_prompt_mlc(
6598
x: str,
6699
labels: str,

0 commit comments

Comments
 (0)