@@ -38,6 +38,8 @@ class DynamicFewShotGPTClassifier(_BaseZeroShotGPTClassifier):
3838 label will be chosen based on probabilities from the training set.
3939 memory_index : Optional[IndexConstructor], default : None
4040 The memory index constructor to use. If None, a SklearnMemoryIndex will be used.
41+ prompt_template: str , A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
42+ If None, the default prompt template will be used.
4143 """
4244
4345 def __init__ (
@@ -48,10 +50,12 @@ def __init__(
4850 openai_model : str = "gpt-3.5-turbo" ,
4951 default_label : str | None = "Random" ,
5052 memory_index : IndexConstructor | None = None ,
53+ prompt_template : str | None = None ,
5154 ):
5255 super ().__init__ (openai_key , openai_org , openai_model , default_label )
5356 self .n_examples = n_examples
5457 self .memory_index = memory_index
58+ self .prompt_template = prompt_template
5559
5660 def fit (
5761 self ,
@@ -96,6 +100,18 @@ def fit(
96100
97101 return self
98102
103+ def _get_prompt_template (self ) -> str :
104+ """Returns the prompt template to use.
105+
106+ Returns
107+ -------
108+ str
109+ prompt template
110+ """
111+ if self .prompt_template is None :
112+ return _TRAINING_SAMPLE_PROMPT_TEMPLATE
113+ return self .prompt_template
114+
99115 def _get_prompt (self , x : str ) -> str :
100116 """Generates the prompt for the given input.
101117
@@ -109,6 +125,7 @@ def _get_prompt(self, x: str) -> str:
109125 str
110126 final prompt
111127 """
128+ prompt_template = self ._get_prompt_template ()
112129 embedding = self .embedding_model_ .transform ([x ])
113130 training_data = []
114131 for cls in self .classes_ :
@@ -118,7 +135,7 @@ def _get_prompt(self, x: str) -> str:
118135 neighbors = [partition [i ] for i in neighbors [0 ]]
119136 training_data .extend (
120137 [
121- _TRAINING_SAMPLE_PROMPT_TEMPLATE .format (x = neighbor , label = cls )
138+ prompt_template .format (x = neighbor , label = cls )
122139 for neighbor in neighbors
123140 ]
124141 )
0 commit comments