
Commit fd51a6e

Added functionality for custom prompts
1 parent 7dadca5 commit fd51a6e

4 files changed (+46, -3 lines)

skllm/models/_base.py

Lines changed: 4 additions & 0 deletions
@@ -130,6 +130,8 @@ class _BaseZeroShotGPTClassifier(BaseClassifier, _OAIMixin):
     default_label : Optional[Union[List[str], str]] , default : 'Random'
         The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
         label will be chosen based on probabilities from the training set.
+    prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
+        If None, the default prompt template will be used.
     """

     def __init__(
@@ -138,10 +140,12 @@ def __init__(
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
         default_label: Optional[Union[List[str], str]] = "Random",
+        prompt_template: Optional[str] = None,
     ):
         self._set_keys(openai_key, openai_org)
         self.openai_model = openai_model
         self.default_label = default_label
+        self.prompt_template = prompt_template

     @abstractmethod
     def _get_prompt(self, x: str) -> str:
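The base class only stores the template; each subclass decides how to consume it when building a prompt. Below is a minimal standalone sketch of that pattern, with the class name, default template string, and label list purely illustrative (they are not part of the library):

from typing import List, Optional

# Illustrative stand-in for the package's default template (the real default lives in skllm).
_DEFAULT_TEMPLATE = "Classify the sample into one of the following labels: {labels}.\nSample: {x}"

class ToyZeroShotClassifier:
    """Hypothetical consumer of a stored prompt_template attribute."""

    def __init__(self, prompt_template: Optional[str] = None, labels: Optional[List[str]] = None):
        # Mirrors the base-class change: the template is only stored at construction time ...
        self.prompt_template = prompt_template
        self.labels = labels or ["positive", "negative"]

    def _get_prompt(self, x: str) -> str:
        # ... and resolved lazily when a prompt is built, falling back to the default.
        template = self.prompt_template if self.prompt_template is not None else _DEFAULT_TEMPLATE
        return template.format(x=x, labels=self.labels)

print(ToyZeroShotClassifier("Label {x} using one of {labels}.")._get_prompt("great movie"))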

skllm/models/gpt/gpt_dyn_few_shot_clf.py

Lines changed: 18 additions & 1 deletion
@@ -38,6 +38,8 @@ class DynamicFewShotGPTClassifier(_BaseZeroShotGPTClassifier):
         label will be chosen based on probabilities from the training set.
     memory_index : Optional[IndexConstructor], default : None
         The memory index constructor to use. If None, a SklearnMemoryIndex will be used.
+    prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
+        If None, the default prompt template will be used.
     """

     def __init__(
@@ -48,10 +50,12 @@ def __init__(
         openai_model: str = "gpt-3.5-turbo",
         default_label: str | None = "Random",
         memory_index: IndexConstructor | None = None,
+        prompt_template: str | None = None,
     ):
         super().__init__(openai_key, openai_org, openai_model, default_label)
         self.n_examples = n_examples
         self.memory_index = memory_index
+        self.prompt_template = prompt_template

     def fit(
         self,
@@ -96,6 +100,18 @@ def fit(

         return self

+    def _get_prompt_template(self) -> str:
+        """Returns the prompt template to use.
+
+        Returns
+        -------
+        str
+            prompt template
+        """
+        if self.prompt_template is None:
+            return _TRAINING_SAMPLE_PROMPT_TEMPLATE
+        return self.prompt_template
+
     def _get_prompt(self, x: str) -> str:
         """Generates the prompt for the given input.

@@ -109,6 +125,7 @@ def _get_prompt(self, x: str) -> str:
         str
             final prompt
         """
+        prompt_template = self._get_prompt_template()
         embedding = self.embedding_model_.transform([x])
         training_data = []
         for cls in self.classes_:
@@ -118,7 +135,7 @@ def _get_prompt(self, x: str) -> str:
             neighbors = [partition[i] for i in neighbors[0]]
             training_data.extend(
                 [
-                    _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
+                    prompt_template.format(x=neighbor, label=cls)
                     for neighbor in neighbors
                 ]
             )
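Note that in this class the custom template replaces _TRAINING_SAMPLE_PROMPT_TEMPLATE, i.e. it is rendered once per retrieved neighbour via .format(x=..., label=...), so a template supplied here should use the {x} and {label} placeholders. A hedged usage sketch; the API key, n_examples value, and toy data are placeholders, and fit/predict require valid OpenAI credentials:

from skllm.models.gpt.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier

# Per-sample template applied to each retrieved training example.
sample_template = "Text: {x}\nLabel: {label}\n"

clf = DynamicFewShotGPTClassifier(
    n_examples=3,                      # number of neighbours retrieved per class
    openai_key="sk-...",               # placeholder key
    prompt_template=sample_template,   # None keeps the library default
)
clf.fit(
    ["loved it", "hated it", "it was fine"],
    ["positive", "negative", "neutral"],
)
print(clf.predict(["best film of the year"]))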

skllm/models/gpt/gpt_few_shot_clf.py

Lines changed: 14 additions & 1 deletion
@@ -45,6 +45,18 @@ def fit(
         self.classes_, self.probabilities_ = self._get_unique_targets(y)
         return self

+    def _get_prompt_template(self) -> str:
+        """Returns the prompt template to use.
+
+        Returns
+        -------
+        str
+            prompt template
+        """
+        if self.prompt_template is None:
+            return _TRAINING_SAMPLE_PROMPT_TEMPLATE
+        return self.prompt_template
+
     def _get_prompt(self, x: str) -> str:
         """Generates the prompt for the given input.

@@ -58,10 +70,11 @@ def _get_prompt(self, x: str) -> str:
         str
             final prompt
         """
+        prompt_template = self._get_prompt_template()
         training_data = []
         for xt, yt in zip(*self.training_data_):
             training_data.append(
-                _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
+                prompt_template.format(x=xt, label=yt)
             )

         training_data_str = "\n".join(training_data)
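For reference, the loop above simply renders the (possibly custom) template once per stored (x, y) pair and joins the results into the few-shot section of the prompt. A standalone illustration of that substitution, independent of the OpenAI-backed classifier; the template string and data are illustrative:

# Roughly what the classifier does with self.training_data_ and the resolved template.
sample_template = "Sample input:\n{x}\nSample target: {label}\n"

training_data = [("loved it", "positive"), ("hated it", "negative")]
rendered = [sample_template.format(x=xt, label=yt) for xt, yt in training_data]
training_data_str = "\n".join(rendered)
print(training_data_str)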

skllm/models/gpt/gpt_zero_shot_clf.py

Lines changed: 10 additions & 1 deletion
@@ -28,6 +28,9 @@ class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
     default_label : Optional[str] , default : 'Random'
         The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
         label will be chosen based on probabilities from the training set.
+    prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
+        If None, the default prompt template will be used.
+
     """

     def __init__(
@@ -36,11 +39,17 @@ def __init__(
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
         default_label: Optional[str] = "Random",
+        prompt_template: Optional[str] = None,
     ):
         super().__init__(openai_key, openai_org, openai_model, default_label)
+        self.prompt_template = prompt_template

     def _get_prompt(self, x) -> str:
-        return build_zero_shot_prompt_slc(x, repr(self.classes_))
+        if self.prompt_template is None:
+            return build_zero_shot_prompt_slc(x, repr(self.classes_))
+        return build_zero_shot_prompt_slc(
+            x, repr(self.classes_), template=self.prompt_template
+        )

     def fit(
         self,
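With the zero-shot classifier the custom template is passed through build_zero_shot_prompt_slc, so the placeholders named in the docstring, {x} and {labels}, apply. A hedged usage sketch; the import path and parameter names come from this file, while the key and toy data are placeholders and predict requires valid OpenAI credentials:

from skllm.models.gpt.gpt_zero_shot_clf import ZeroShotGPTClassifier

# Custom zero-shot template; {labels} receives repr(self.classes_) and {x} the sample.
template = (
    "Classify the following text into exactly one of these categories: {labels}.\n"
    "Text: {x}\n"
    "Respond with the label only."
)

clf = ZeroShotGPTClassifier(
    openai_key="sk-...",        # placeholder key
    prompt_template=template,   # None keeps the default prompt builder behaviour
)
clf.fit(["great movie", "awful service"], ["positive", "negative"])
print(clf.predict(["the food was terrible"]))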
