Commit 157548d

added readme + minor code refactoring + demo data
1 parent 0034d7c commit 157548d

7 files changed: +257 -20 lines changed


README.md

Lines changed: 116 additions & 1 deletion
@@ -1 +1,116 @@
1-
# scikit-llm
1+
<p align="center">
2+
<img src="https://github.com/iryna-kondr/scikit-llm/blob/main/logo.png?raw=true" height="200"/>
3+
</p>
4+
5+
# Scikit-LLM: Sklearn Meets Large Language Models
6+
7+
Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks.
8+
9+
## Installation 💾
10+
11+
```bash
12+
pip install scikit-llm
13+
```
14+
15+
## Documentation 📚
16+
17+
### Configuring OpenAI API Key
18+
At the moment Scikit-LLM is only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required.
19+
20+
```python
21+
from skllm.config import SKLLMConfig
22+
SKLLMConfig.set_openai_key("<YOUR_KEY>")
23+
SKLLMConfig.set_openai_org("<YOUR_ORGANISATION>")
24+
```
25+
26+
### Zero-Shot Text Classification
27+
28+
One of ChatGPT's most useful capabilities is classifying text without being re-trained. The only requirement is that the labels are descriptive.
29+
30+
We provide a class `ZeroShotGPTClassifier` that allows you to create such a model as a regular scikit-learn classifier.
31+
32+
Example 1: Training as a regular classifier
33+
```python
34+
from skllm import ZeroShotGPTClassifier
35+
from skllm.datasets import get_classification_dataset
36+
37+
# demo sentiment analysis dataset
38+
# labels: positive, negative, neutral
39+
X, y = get_classification_dataset()
40+
41+
clf = ZeroShotGPTClassifier(openai_model="gpt-3.5-turbo")
42+
clf.fit(X, y)
43+
labels = clf.predict(X)
44+
```
45+
Scikit-LLM will automatically query the OpenAI API and transform the response into a regular list of labels.
46+
47+
Additionally, Scikit-LLM will ensure that the obtained response contains a valid label. If this is not the case, a label will be selected randomly (label probabilities are proportional to label occurrences in the training set).
48+
49+
Example 2: Training without labeled data
50+
51+
Since the training data is not strictly required, it can be omitted entirely. The only thing that has to be provided is the list of candidate labels.
52+
53+
```python
54+
from skllm import ZeroShotGPTClassifier
55+
from skllm.datasets import get_classification_dataset
56+
57+
X, _ = get_classification_dataset()
58+
59+
clf = ZeroShotGPTClassifier()
60+
clf.fit(None, ['positive', 'negative', 'neutral'])
61+
labels = clf.predict(X)
62+
63+
```
64+
65+
### Multi-Label Zero-Shot Text Classification
66+
67+
The `MultiLabelZeroShotGPTClassifier` class performs classification in a multi-label setting, in which each sample may be assigned one or several distinct classes.
68+
69+
Example:
70+
71+
```python
72+
from skllm import MultiLabelZeroShotGPTClassifier
73+
from skllm.datasets import get_multilabel_classification_dataset
74+
75+
X, y = get_multilabel_classification_dataset()
76+
77+
clf = MultiLabelZeroShotGPTClassifier(max_labels=3)
78+
clf.fit(X, y)
79+
labels = clf.predict(X)
80+
```
81+
82+
As with `ZeroShotGPTClassifier`, providing only the candidate labels is sufficient. However, this classifier expects `y` of type `List[List[str]]`.
83+
84+
```python
85+
from skllm import MultiLabelZeroShotGPTClassifier
86+
from skllm.datasets import get_multilabel_classification_dataset
87+
88+
X, _ = get_multilabel_classification_dataset()
89+
candidate_labels = [
90+
"Quality",
91+
"Price",
92+
"Delivery",
93+
"Service",
94+
"Product Variety",
95+
"Customer Support",
96+
"Packaging",
97+
"User Experience",
98+
"Return Policy",
99+
"Product Information"
100+
]
101+
clf = MultiLabelZeroShotGPTClassifier(max_labels=3)
102+
clf.fit(None, [candidate_labels])
103+
labels = clf.predict(X)
104+
```
105+
106+
## Roadmap 🧭
107+
108+
- [x] Zero-Shot Classification with OpenAI GPT 3/4
109+
- [x] Multiclass classification
110+
- [x] Multi-label classification
111+
- [x] ChatGPT models
112+
- [ ] InstructGPT models
113+
- [ ] Few shot classifier
114+
- [ ] GPT Vectorizer
115+
- [ ] GPT Fine-tuning (optional)
116+
- [ ] Integration of other LLMs
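
The fallback described in the Zero-Shot section of this README (an invalid response is replaced by a label drawn in proportion to its frequency in the training set) could be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the library's implementation: `label_frequencies` and `validate_or_sample` are hypothetical names introduced here.

```python
import random
from collections import Counter
from typing import List, Tuple


def label_frequencies(y: List[str]) -> Tuple[List[str], List[float]]:
    """Return the unique labels in y and their relative frequencies."""
    counts = Counter(y)
    total = sum(counts.values())
    classes = list(counts)  # Counter keeps first-seen order
    probs = [counts[c] / total for c in classes]
    return classes, probs


def validate_or_sample(label: str, classes: List[str], probs: List[float]) -> str:
    """Keep a valid label; otherwise draw one weighted by its frequency."""
    if label in classes:
        return label
    return random.choices(classes, weights=probs)[0]


classes, probs = label_frequencies(["positive", "positive", "negative", "neutral"])
print(classes, probs)
print(validate_or_sample("negative", classes, probs))     # valid, returned unchanged
print(validate_or_sample("not-a-label", classes, probs))  # replaced by a sampled label
```

Under this sketch, fitting on candidate labels alone gives every label a count of one, so the fallback draw becomes uniform.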

skllm/config.py

Lines changed: 12 additions & 3 deletions
@@ -1,14 +1,23 @@
 import os
 from typing import Optional
 
-_OPENAI_KEY_VAR = 'SLLM_CONFIG_OPENAI_KEY'
+_OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY"
+_OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG"
 
-class SLLMConfig():
+class SKLLMConfig():
 
     @staticmethod
     def set_openai_key(key: str) -> None:
         os.environ[_OPENAI_KEY_VAR] = key
 
     @staticmethod
     def get_openai_key() -> Optional[str]:
-        return os.environ.get(_OPENAI_KEY_VAR, None)
+        return os.environ.get(_OPENAI_KEY_VAR, None)
+
+    @staticmethod
+    def set_openai_org(key: str) -> None:
+        os.environ[_OPENAI_ORG_VAR] = key
+
+    @staticmethod
+    def get_openai_org() -> Optional[str]:
+        return os.environ.get(_OPENAI_ORG_VAR, None)

skllm/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
1+
from skllm.datasets.multi_class import get_classification_dataset
2+
from skllm.datasets.multi_label import get_multilabel_classification_dataset

skllm/datasets/multi_class.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
1+
def get_classification_dataset():
2+
X = [
3+
r"I was absolutely blown away by the performances in 'Summer's End'. The acting was top-notch, and the plot had me gripped from start to finish. A truly captivating cinematic experience that I would highly recommend.",
4+
r"The special effects in 'Star Battles: Nebula Conflict' were out of this world. I felt like I was actually in space. The storyline was incredibly engaging and left me wanting more. Excellent film.",
5+
r"'The Lost Symphony' was a masterclass in character development and storytelling. The score was hauntingly beautiful and complimented the intense, emotional scenes perfectly. Kudos to the director and cast for creating such a masterpiece.",
6+
r"I was pleasantly surprised by 'Love in the Time of Cholera'. The romantic storyline was heartwarming and the characters were incredibly realistic. The cinematography was also top-notch. A must-watch for all romance lovers.",
7+
r"I went into 'Marble Street' with low expectations, but I was pleasantly surprised. The suspense was well-maintained throughout, and the twist at the end was something I did not see coming. Bravo!",
8+
r"'The Great Plains' is a touching portrayal of life in rural America. The performances were heartfelt and the scenery was breathtaking. I was moved to tears by the end. It's a story that will stay with me for a long time.",
9+
r"The screenwriting in 'Under the Willow Tree' was superb. The dialogue felt real and the characters were well-rounded. The performances were also fantastic. I haven't enjoyed a movie this much in a while.",
10+
r"'Nightshade' is a brilliant take on the superhero genre. The protagonist was relatable and the villain was genuinely scary. The action sequences were thrilling and the storyline was engaging. I can't wait for the sequel.",
11+
r"The cinematography in 'Awakening' was nothing short of spectacular. The visuals alone are worth the ticket price. The storyline was unique and the performances were solid. An overall fantastic film.",
12+
r"'Eternal Embers' was a cinematic delight. The storytelling was original and the performances were exceptional. The director's vision was truly brought to life on the big screen. A must-see for all movie lovers.",
13+
r"I was thoroughly disappointed with 'Silver Shadows'. The plot was confusing and the performances were lackluster. I wouldn't recommend wasting your time on this one.",
14+
r"'The Darkened Path' was a disaster. The storyline was unoriginal, the acting was wooden and the special effects were laughably bad. Save your money and skip this one.",
15+
r"I had high hopes for 'The Final Frontier', but it failed to deliver. The plot was full of holes and the characters were poorly developed. It was a disappointing experience.",
16+
r"'The Fall of the Phoenix' was a letdown. The storyline was confusing and the characters were one-dimensional. I found myself checking my watch multiple times throughout the movie.",
17+
r"I regret wasting my time on 'Emerald City'. The plot was nonsensical and the performances were uninspired. It was a major disappointment.",
18+
r"I found 'Hollow Echoes' to be a complete mess. The plot was non-existent, the performances were overdone, and the pacing was all over the place. Definitely not worth the hype.",
19+
r"'Underneath the Stars' was a huge disappointment. The storyline was predictable and the acting was mediocre at best. I was expecting so much more.",
20+
r"I was left unimpressed by 'River's Edge'. The plot was convoluted, the characters were uninteresting, and the ending was unsatisfying. It's a pass for me.",
21+
r"The acting in 'Desert Mirage' was subpar, and the plot was boring. I found myself yawning multiple times throughout the movie. Save your time and skip this one.",
22+
r"'Crimson Dawn' was a major letdown. The plot was cliched and the characters were flat. The special effects were also poorly executed. I wouldn't recommend it.",
23+
r"'Remember the Days' was utterly forgettable. The storyline was dull, the performances were bland, and the dialogue was cringeworthy. A big disappointment.",
24+
r"'The Last Frontier' was simply okay. The plot was decent and the performances were acceptable. However, it lacked a certain spark to make it truly memorable.",
25+
r"'Through the Storm' was not bad, but it wasn't great either. The storyline was somewhat predictable, and the characters were somewhat stereotypical. It was an average movie at best.",
26+
r"I found 'After the Rain' to be pretty average. The plot was okay and the performances were decent, but it didn't leave a lasting impression on me.",
27+
r"'Beyond the Horizon' was neither good nor bad. The plot was interesting enough, but the characters were not very well developed. It was an okay watch.",
28+
r"'The Silent Echo' was a mediocre movie. The storyline was passable and the performances were fair, but it didn't stand out in any way.",
29+
r"I thought 'The Scent of Roses' was pretty average. The plot was somewhat engaging, and the performances were okay, but it didn't live up to my expectations.",
30+
r"'Under the Same Sky' was an okay movie. The plot was decent, and the performances were fine, but it lacked depth and originality. It's not a movie I would watch again.",
31+
r"'Chasing Shadows' was fairly average. The plot was not bad, and the performances were passable, but it lacked a certain spark. It was just okay.",
32+
r"'Beneath the Surface' was pretty run-of-the-mill. The plot was decent, the performances were okay, but it wasn't particularly memorable. It was an okay movie.",
33+
]
34+
35+
36+
y = (
37+
["positive" for _ in range(10)]
38+
+ ["negative" for _ in range(10)]
39+
+ ["neutral" for _ in range(10)]
40+
)
41+
42+
return X, y

skllm/datasets/multi_label.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
1+
def get_multilabel_classification_dataset():
2+
X = [
3+
"The product was of excellent quality, and the packaging was also very good. Highly recommend!",
4+
"The delivery was super fast, but the product did not match the information provided on the website.",
5+
"Great variety of products, but the customer support was quite unresponsive.",
6+
"Affordable prices and an easy-to-use website. A great shopping experience overall.",
7+
"The delivery was delayed, and the packaging was damaged. Not a good experience.",
8+
"Excellent customer support, but the return policy is quite complicated.",
9+
"The product was not as described. However, the return process was easy and quick.",
10+
"Great service and fast delivery. The product was also of high quality.",
11+
"The prices are a bit high. However, the product quality and user experience are worth it.",
12+
"The website provides detailed information about products. The delivery was also very fast."
13+
]
14+
15+
y = [
16+
["Quality", "Packaging"],
17+
["Delivery", "Product Information"],
18+
["Product Variety", "Customer Support"],
19+
["Price", "User Experience"],
20+
["Delivery", "Packaging"],
21+
["Customer Support", "Return Policy"],
22+
["Product Information", "Return Policy"],
23+
["Service", "Delivery", "Quality"],
24+
["Price", "Quality", "User Experience"],
25+
["Product Information", "Delivery"],
26+
]
27+
28+
return X, y
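
The multi-label targets returned above are a list of label lists, which is also why the README passes `[candidate_labels]` (a single inner list) when no training data is available. One plausible way to derive the candidate labels and their frequencies from such a `y` is sketched below; `unique_multilabel_targets` is a hypothetical name, and the logic is an assumption rather than the library's own code.

```python
from collections import Counter
from itertools import chain
from typing import List, Tuple


def unique_multilabel_targets(y: List[List[str]]) -> Tuple[List[str], List[float]]:
    """Flatten a list of label lists into unique labels and relative frequencies."""
    flat = list(chain.from_iterable(y))
    counts = Counter(flat)
    classes = list(counts)  # first-seen order
    probs = [counts[c] / len(flat) for c in classes]
    return classes, probs


y = [["Quality", "Packaging"], ["Delivery", "Packaging"]]
print(unique_multilabel_targets(y))

# A single inner list of candidates gives every label the same weight.
print(unique_multilabel_targets([["Quality", "Price", "Delivery"]]))
```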

skllm/models/gpt_zero_shot_clf.py

Lines changed: 48 additions & 10 deletions
@@ -7,7 +7,12 @@
 from abc import ABC, abstractmethod
 from sklearn.base import BaseEstimator, ClassifierMixin
 from skllm.openai.prompts import get_zero_shot_prompt_slc, get_zero_shot_prompt_mlc
-from skllm.openai.chatgpt import construct_message, get_chat_completion, extract_json_key
+from skllm.openai.chatgpt import (
+    construct_message,
+    get_chat_completion,
+    extract_json_key,
+)
+from skllm.config import SKLLMConfig as _Config
 
 
 class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin):
@@ -23,20 +28,37 @@ def __init__(
 
     def fit(
         self,
-        X: Union[np.ndarray, pd.Series, List[str]],
+        X: Optional[Union[np.ndarray, pd.Series, List[str]]],
         y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
     ):
+        if isinstance(X, np.ndarray):
+            X = np.squeeze(X)
         self.classes_, self.probabilities_ = self._get_unique_targets(y)
         return self
 
-    def predict(self, X):
+    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
+        if isinstance(X, np.ndarray):
+            X = np.squeeze(X)
         predictions = []
         for i in tqdm(range(len(X))):
             predictions.append(self._predict_single(X[i]))
         return predictions
 
-    def _get_openai_keys(self):
-        return self.openai_key, self.openai_org
+    def _get_openai_key(self):
+        key = self.openai_key
+        if key is None:
+            key = _Config.get_openai_key()
+        if key is None:
+            raise RuntimeError("OpenAI key was not found")
+        return key
+
+    def _get_openai_org(self):
+        key = self.openai_org
+        if key is None:
+            key = _Config.get_openai_org()
+        if key is None:
+            raise RuntimeError("OpenAI organization was not found")
+        return key
 
     @abstractmethod
     def _extract_labels(self, y: Any) -> List[str]:
@@ -56,13 +78,13 @@ def _get_unique_targets(self, y):
 
         return classes, probs
 
-    def _get_completion(self, x):
+    def _get_chat_completion(self, x):
         prompt = self._get_prompt(x)
         msgs = []
         msgs.append(construct_message("system", "You are a text classification model."))
         msgs.append(construct_message("user", prompt))
         completion = get_chat_completion(
-            msgs, self.openai_key, self.openai_org, self.openai_model
+            msgs, self._get_openai_key(), self._get_openai_org(), self.openai_model
         )
         return completion
 
@@ -87,7 +109,7 @@ def _get_prompt(self, x) -> str:
         return get_zero_shot_prompt_slc(x, self.classes_)
 
     def _predict_single(self, x):
-        completion = self._get_completion(x)
+        completion = self._get_chat_completion(x)
         try:
             label = str(
                 extract_json_key(completion.choices[0].message["content"], "label")
@@ -99,6 +121,15 @@ def _predict_single(self, x):
             label = random.choices(self.classes_, self.probabilities_)[0]
         return label
 
+    def fit(
+        self,
+        X: Optional[Union[np.ndarray, pd.Series, List[str]]],
+        y: Union[np.ndarray, pd.Series, List[str]],
+    ):
+        if isinstance(y, np.ndarray):
+            y = np.squeeze(y)
+        return super().fit(X, y)
+
 
 class MultiLabelZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
     def __init__(
@@ -125,7 +156,7 @@ def _get_prompt(self, x) -> str:
         return get_zero_shot_prompt_mlc(x, self.classes_, self.max_labels)
 
     def _predict_single(self, x):
-        completion = self._get_completion(x)
+        completion = self._get_chat_completion(x)
         try:
             labels = extract_json_key(completion.choices[0].message["content"], "label")
             if not isinstance(labels, list):
@@ -139,4 +170,11 @@ def _predict_single(self, x):
             labels = labels[: self.max_labels - 1]
         elif len(labels) < 1:
             labels = [random.choices(self.classes_, self.probabilities_)[0]]
-        return labels
+        return labels
+
+    def fit(
+        self,
+        X: Optional[Union[np.ndarray, pd.Series, List[str]]],
+        y: List[List[str]],
+    ):
+        return super().fit(X, y)
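
The multi-label post-processing in `_predict_single` (parse the JSON reply, require a list, cap it at `max_labels`, fall back to a sampled label when nothing usable remains) can be illustrated in isolation. The sketch below is an assumption-laden approximation, not the code from this commit: `postprocess_labels` is a hypothetical function, it parses with the standard `json` module instead of `extract_json_key`, and it additionally drops labels that are not in `classes`.

```python
import json
import random
from typing import List


def postprocess_labels(
    raw: str, classes: List[str], probs: List[float], max_labels: int
) -> List[str]:
    """Turn a raw JSON reply into at most max_labels known labels."""
    try:
        labels = json.loads(raw)["label"]
        if not isinstance(labels, list):
            raise ValueError("label is not a list")
    except (ValueError, KeyError, TypeError):
        labels = []  # unparseable or malformed reply
    labels = [label for label in labels if label in classes]
    labels = labels[:max_labels]  # keep at most max_labels entries
    if not labels:
        # Nothing usable: sample one label weighted by training frequency.
        labels = [random.choices(classes, weights=probs)[0]]
    return labels


classes = ["Quality", "Price", "Delivery", "Service"]
probs = [0.25, 0.25, 0.25, 0.25]
print(postprocess_labels('{"label": ["Quality", "Price", "Delivery", "Service"]}', classes, probs, 3))
print(postprocess_labels("not json at all", classes, probs, 3))
```

Note that the slice here is `[:max_labels]`. The context line in the last hunk above slices to `self.max_labels - 1`, which keeps one label fewer than `max_labels`; that may be worth a second look if it is not intended.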

skllm/openai/chatgpt.py

Lines changed: 9 additions & 6 deletions
@@ -6,14 +6,17 @@ def construct_message(role, content):
         raise ValueError("Invalid role")
     return {"role": role, "content": content}
 
-def get_chat_completion(messages, key, org, model="gpt-3.5-turbo"):
+def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries = 3):
     openai.api_key = key
     openai.organization = org
-    completion = openai.ChatCompletion.create(
-        model=model, temperature=0., messages=messages
-    )
-
-    return completion
+    for _ in range(max_retries):
+        try:
+            completion = openai.ChatCompletion.create(
+                model=model, temperature=0., messages=messages
+            )
+            return completion
+        except Exception:
+            continue
 
 def extract_json_key(json_, key):
     try:
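
As written, the retry loop added to `get_chat_completion` returns `None` implicitly when every attempt fails, because each exception is swallowed by `continue`. A variant that surfaces the last error instead could look roughly like the sketch below. `call_with_retries` and `flaky` are hypothetical names, and no OpenAI call is made; the wrapped function stands in for the API request.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def call_with_retries(fn: Callable[[], T], max_retries: int = 3) -> T:
    """Call fn up to max_retries times; raise if every attempt fails."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    last_error: Optional[Exception] = None
    for _ in range(max_retries):
        try:
            return fn()
        except Exception as err:  # remember the failure and try again
            last_error = err
    raise RuntimeError(f"all {max_retries} attempts failed") from last_error


attempts = {"count": 0}


def flaky() -> str:
    """Stand-in for an API call that fails twice before succeeding."""
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("temporary failure")
    return "ok"


print(call_with_retries(flaky, max_retries=3))
print(attempts["count"])
```

Raising after the final attempt lets the caller tell a failed request apart from an unparseable reply, which the classifier's own fallback already handles.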
