
Commit c3ec570

Merge pull request #70 from KennethEnevoldsen/added-custom-prompt-functionality
Added functionality for custom prompts
2 parents 1c93c89 + f993cab commit c3ec570

File tree

8 files changed (+66 / -19 lines changed):

.gitignore
skllm/models/_base.py
skllm/models/gpt/gpt_dyn_few_shot_clf.py
skllm/models/gpt/gpt_few_shot_clf.py
skllm/models/gpt/gpt_zero_shot_clf.py
tests/test_chatgpt.py
tests/test_gpt_few_shot_clf.py
tests/test_gpt_zero_shot_clf.py

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -162,3 +162,6 @@ test.py
 tmp.ipynb
 tmp.py
 *.pickle
+
+# vscode
+.vscode/

skllm/models/_base.py

Lines changed: 4 additions & 0 deletions
@@ -130,6 +130,8 @@ class _BaseZeroShotGPTClassifier(BaseClassifier, _OAIMixin):
     default_label : Optional[Union[List[str], str]] , default : 'Random'
         The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
         label will be chosen based on probabilities from the training set.
+    prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
+        If None, the default prompt template will be used.
     """

     def __init__(
@@ -138,10 +140,12 @@ def __init__(
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
         default_label: Optional[Union[List[str], str]] = "Random",
+        prompt_template: Optional[str] = None,
     ):
         self._set_keys(openai_key, openai_org)
         self.openai_model = openai_model
         self.default_label = default_label
+        self.prompt_template = prompt_template

     @abstractmethod
     def _get_prompt(self, x: str) -> str:
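
Note (not part of the diff): the pattern introduced here, storing prompt_template on the estimator and falling back to a library default when it is None, is the one the subclasses below follow. A minimal self-contained sketch of that pattern; _DEFAULT_TEMPLATE and _SketchClassifier are illustrative names, not identifiers from this commit:

from typing import List, Optional

# Stand-in for the library's default template; the real defaults live elsewhere in skllm.
_DEFAULT_TEMPLATE = "Classify the sample into one of {labels}.\nSample: {x}"


class _SketchClassifier:
    def __init__(self, prompt_template: Optional[str] = None):
        # Stored as-is, scikit-learn style; no validation at construction time.
        self.prompt_template = prompt_template

    def _get_prompt(self, x: str, labels: List[str]) -> str:
        # Fall back to the default template when no custom one was provided.
        template = self.prompt_template if self.prompt_template is not None else _DEFAULT_TEMPLATE
        return template.format(x=x, labels=labels)


print(_SketchClassifier()._get_prompt("great product", ["positive", "negative"]))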

skllm/models/gpt/gpt_dyn_few_shot_clf.py

Lines changed: 18 additions & 1 deletion
@@ -38,6 +38,8 @@ class DynamicFewShotGPTClassifier(_BaseZeroShotGPTClassifier):
         label will be chosen based on probabilities from the training set.
     memory_index : Optional[IndexConstructor], default : None
         The memory index constructor to use. If None, a SklearnMemoryIndex will be used.
+    prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
+        If None, the default prompt template will be used.
     """

     def __init__(
@@ -48,10 +50,12 @@ def __init__(
         openai_model: str = "gpt-3.5-turbo",
         default_label: str | None = "Random",
         memory_index: IndexConstructor | None = None,
+        prompt_template: str | None = None,
     ):
         super().__init__(openai_key, openai_org, openai_model, default_label)
         self.n_examples = n_examples
         self.memory_index = memory_index
+        self.prompt_template = prompt_template

     def fit(
         self,
@@ -96,6 +100,18 @@ def fit(

         return self

+    def _get_prompt_template(self) -> str:
+        """Returns the prompt template to use.
+
+        Returns
+        -------
+        str
+            prompt template
+        """
+        if self.prompt_template is None:
+            return _TRAINING_SAMPLE_PROMPT_TEMPLATE
+        return self.prompt_template
+
     def _get_prompt(self, x: str) -> str:
         """Generates the prompt for the given input.

@@ -109,6 +125,7 @@ def _get_prompt(self, x: str) -> str:
         str
             final prompt
         """
+        prompt_template = self._get_prompt_template()
         embedding = self.embedding_model_.transform([x])
         training_data = []
         for cls in self.classes_:
@@ -118,7 +135,7 @@ def _get_prompt(self, x: str) -> str:
             neighbors = [partition[i] for i in neighbors[0]]
             training_data.extend(
                 [
-                    _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
+                    prompt_template.format(x=neighbor, label=cls)
                     for neighbor in neighbors
                 ]
             )
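
Note (not part of the diff): a hedged usage sketch of the new argument on DynamicFewShotGPTClassifier. The toy data and template wording are invented, predict is assumed to follow the library's usual scikit-learn style API, and configured OpenAI credentials are assumed. The custom template replaces _TRAINING_SAMPLE_PROMPT_TEMPLATE for every retrieved neighbor, so it must expose the {x} and {label} placeholders.

from skllm import DynamicFewShotGPTClassifier

X = ["I loved this movie", "Terrible acting", "Great soundtrack", "Boring plot"]
y = ["positive", "negative", "positive", "negative"]

clf = DynamicFewShotGPTClassifier(
    n_examples=1,  # one nearest neighbour per class goes into the prompt
    prompt_template="Example text: {x}\nExample label: {label}\n",
)
clf.fit(X, y)  # builds the per-class example index from the training texts
print(clf.predict(["An absolute masterpiece"]))  # assumed API; needs OpenAI credentials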

skllm/models/gpt/gpt_few_shot_clf.py

Lines changed: 14 additions & 1 deletion
@@ -46,6 +46,18 @@ def fit(
         self.classes_, self.probabilities_ = self._get_unique_targets(y)
         return self

+    def _get_prompt_template(self) -> str:
+        """Returns the prompt template to use.
+
+        Returns
+        -------
+        str
+            prompt template
+        """
+        if self.prompt_template is None:
+            return _TRAINING_SAMPLE_PROMPT_TEMPLATE
+        return self.prompt_template
+
     def _get_prompt(self, x: str) -> str:
         """Generates the prompt for the given input.

@@ -59,10 +71,11 @@ def _get_prompt(self, x: str) -> str:
         str
             final prompt
         """
+        prompt_template = self._get_prompt_template()
         training_data = []
         for xt, yt in zip(*self.training_data_):
             training_data.append(
-                _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
+                prompt_template.format(x=xt, label=yt)
             )

         training_data_str = "\n".join(training_data)
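
Note (not part of the diff): FewShotGPTClassifier gains the same fallback helper. Its prompt_template presumably arrives through the inherited base-class constructor, since no __init__ change appears in this diff; the sketch below rests on that assumption and uses invented data. The template must provide {x} and {label}.

from skllm import FewShotGPTClassifier

# prompt_template assumed to be accepted via the base-class __init__ (not shown in this diff)
clf = FewShotGPTClassifier(prompt_template="Text: {x} -> Label: {label}\n")
clf.fit(
    ["cheap flights to Paris", "how to fix a flat tire"],
    ["travel", "automotive"],
)
# _get_prompt only formats strings, so the custom template can be inspected
# without making an API call.
print(clf._get_prompt("best hotels in Rome"))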

skllm/models/gpt/gpt_zero_shot_clf.py

Lines changed: 10 additions & 1 deletion
@@ -28,6 +28,9 @@ class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
     default_label : Optional[str] , default : 'Random'
         The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
         label will be chosen based on probabilities from the training set.
+    prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
+        If None, the default prompt template will be used.
+
     """

     def __init__(
@@ -36,11 +39,17 @@ def __init__(
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
         default_label: Optional[str] = "Random",
+        prompt_template: Optional[str] = None,
     ):
         super().__init__(openai_key, openai_org, openai_model, default_label)
+        self.prompt_template = prompt_template

     def _get_prompt(self, x) -> str:
-        return build_zero_shot_prompt_slc(x, repr(self.classes_))
+        if self.prompt_template is None:
+            return build_zero_shot_prompt_slc(x, repr(self.classes_))
+        return build_zero_shot_prompt_slc(
+            x, repr(self.classes_), template=self.prompt_template
+        )

     def fit(
         self,
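
Note (not part of the diff): a hedged zero-shot example with invented wording and data. A custom template needs the {x} and {labels} placeholders, and any literal braces (for example a JSON reply format like the one the library's default prompt asks for) must be doubled so that str.format leaves them intact. predict is assumed to require configured OpenAI credentials.

from skllm import ZeroShotGPTClassifier

template = (
    "You are a text classifier.\n"
    "Pick exactly one label from {labels} for the text below and reply with a JSON\n"
    'object of the form {{"label": "<label>"}}.\n'
    "Text: {x}"
)

clf = ZeroShotGPTClassifier(openai_model="gpt-3.5-turbo", prompt_template=template)
clf.fit(
    ["The battery life is fantastic", "The screen cracked after a week"],
    ["positive", "negative"],
)
predictions = clf.predict(["Setup was quick and painless"])  # chat completion call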

tests/test_chatgpt.py

Lines changed: 15 additions & 10 deletions
@@ -1,15 +1,11 @@
 import unittest
 from unittest.mock import patch

-from skllm.openai.chatgpt import (
-    construct_message,
-    extract_json_key,
-    get_chat_completion,
-)
+from skllm.openai.chatgpt import construct_message, get_chat_completion
+from skllm.utils import extract_json_key


 class TestChatGPT(unittest.TestCase):
-
     @patch("skllm.openai.credentials.set_credentials")
     @patch("openai.ChatCompletion.create")
     def test_get_chat_completion(self, mock_create, mock_set_credentials):
@@ -21,9 +17,18 @@ def test_get_chat_completion(self, mock_create, mock_set_credentials):

         result = get_chat_completion(messages, key, org, model)

-        self.assertTrue(mock_set_credentials.call_count <= 1, "set_credentials should be called at most once")
-        self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception "
-                         "on the first call")
+        self.assertTrue(
+            mock_set_credentials.call_count <= 1,
+            "set_credentials should be called at most once",
+        )
+        self.assertEqual(
+            mock_create.call_count,
+            2,
+            (
+                "ChatCompletion.create should be called twice due to an exception "
+                "on the first call"
+            ),
+        )
         self.assertEqual(result, "success")

     def test_construct_message(self):
@@ -45,5 +50,5 @@ def test_extract_json_key(self):
         self.assertEqual(result_with_invalid_key, None)


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

tests/test_gpt_few_shot_clf.py

Lines changed: 1 addition & 2 deletions
@@ -5,8 +5,7 @@

 import numpy as np

-from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
-from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
+from skllm import DynamicFewShotGPTClassifier, FewShotGPTClassifier


 class TestFewShotGPTClassifier(unittest.TestCase):

tests/test_gpt_zero_shot_clf.py

Lines changed: 1 addition & 4 deletions
@@ -5,10 +5,7 @@

 import numpy as np

-from skllm.models.gpt_zero_shot_clf import (
-    MultiLabelZeroShotGPTClassifier,
-    ZeroShotGPTClassifier,
-)
+from skllm import MultiLabelZeroShotGPTClassifier, ZeroShotGPTClassifier


 def _get_ret(label):
