|
1 | | -from typing import Optional, Union, List |
| 1 | +import json |
| 2 | +import uuid |
| 3 | +from typing import List, Optional, Union |
| 4 | + |
| 5 | +import numpy as np |
2 | 6 | import pandas as pd |
| 7 | + |
3 | 8 | from skllm.models._base import _BaseZeroShotGPTClassifier |
4 | | -from skllm.prompts.builders import build_zero_shot_prompt_slc |
5 | 9 | from skllm.openai.credentials import set_credentials |
6 | | -from skllm.openai.tuning import create_tuning_job, await_results, delete_file |
7 | | -import numpy as np |
8 | | -import json |
9 | | -import uuid |
| 10 | +from skllm.openai.tuning import await_results, create_tuning_job, delete_file |
| 11 | +from skllm.prompts.builders import ( |
| 12 | + build_zero_shot_prompt_mlc, |
| 13 | + build_zero_shot_prompt_slc, |
| 14 | +) |
| 15 | + |
| 16 | +from skllm.utils import extract_json_key |
| 17 | + |
# Format string with two placeholders: `{x}` (the sample text, wrapped in triple
# backticks) and `{label}` (the sample's target).
# NOTE(review): presumably rendered once per training example when building the
# fine-tuning data — the call site is not visible here; confirm against the caller.
| 18 | +_TRAINING_SAMPLE_PROMPT_TEMPLATE = """
| 19 | +Sample input:
| 20 | +```{x}```
| 21 | +
|
| 22 | +Sample target: {label}
| 23 | +"""
10 | 24 |
|
11 | 25 |
|
12 | 26 | def _build_clf_example( |
@@ -111,6 +125,126 @@ def fit( |
111 | 125 | return self |
112 | 126 |
|
113 | 127 |
|
class MultiLabelGPTClassifier(_BaseZeroShotGPTClassifier, _Tunable):
    """Fine-tunable GPT classifier for multi-label classification.

    Parameters
    ----------
    base_model : str, default "gpt-3.5-turbo-0613"
        Model to fine-tune; must be listed in ``supported_models``.
    default_label : Optional[Union[List[str], str]], default "Random"
        Forwarded to the base class. Must be a list of strings or the literal
        string ``"Random"``; any other string is rejected.
    openai_key : Optional[str], default None
        Forwarded to the base class.
    openai_org : Optional[str], default None
        Forwarded to the base class.
    n_epochs : Optional[int], default None
        Stored on the instance.
        NOTE(review): presumably consumed by ``_Tunable`` during tuning - confirm.
    custom_suffix : Optional[str], default "skllm"
        Stored on the instance.
        NOTE(review): presumably consumed by ``_Tunable`` during tuning - confirm.
    max_labels : int, default 3
        Maximum number of labels returned per sample; must be at least 2.

    Raises
    ------
    ValueError
        If ``max_labels`` is below 2, if ``default_label`` is a string other
        than ``"Random"``, or if ``base_model`` is not supported.
    """

    supported_models = ["gpt-3.5-turbo-0613"]

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        default_label: Optional[Union[List[str], str]] = "Random",
        openai_key: Optional[str] = None,
        openai_org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
        max_labels: int = 3,
    ):
        self.base_model = base_model
        self.n_epochs = n_epochs
        self.custom_suffix = custom_suffix
        if max_labels < 2:
            raise ValueError("max_labels should be at least 2")
        if isinstance(default_label, str) and default_label != "Random":
            raise ValueError("default_label should be a list of strings or 'Random'")
        self.max_labels = max_labels

        if base_model not in self.supported_models:
            raise ValueError(
                f"Model {base_model} is not supported. Supported models are"
                f" {self.supported_models}"
            )
        # NOTE(review): "undefined" looks like a placeholder until a tuned model
        # exists - confirm against _Tunable.
        super().__init__(
            openai_model="undefined",
            default_label=default_label,
            openai_key=openai_key,
            openai_org=openai_org,
        )

    def _get_prompt(self, x: str) -> str:
        """Generates the prompt for the given input.

        Parameters
        ----------
        x : str
            sample

        Returns
        -------
        str
            final prompt, built from ``repr(self.classes_)`` and ``self.max_labels``
        """
        return build_zero_shot_prompt_mlc(
            x=x,
            labels=repr(self.classes_),
            max_cats=self.max_labels,
        )

    def _extract_labels(self, y) -> List[str]:
        """Flattens per-sample label collections into a single list.

        Parameters
        ----------
        y : Any
            iterable whose items are themselves iterables of labels

        Returns
        -------
        List[str]
            all labels in encounter order; duplicates are kept
        """
        return [label for sample_labels in y for label in sample_labels]

    def _predict_single(self, x):
        """Predicts the labels for a single sample.

        The "label" key is extracted from the completion content. A
        non-list value is split on commas, and every label is stripped of
        surrounding whitespace. Labels not present in ``self.classes_`` are
        dropped. If nothing valid remains, the result of
        ``self._get_default_label()`` is used instead. At most
        ``self.max_labels`` labels are returned.

        Parameters
        ----------
        x : str
            sample

        Returns
        -------
        labels
            the valid predicted labels, or the default label value
        """
        completion = self._get_chat_completion(x)
        try:
            labels = extract_json_key(
                completion["choices"][0]["message"]["content"], "label"
            )
            if not isinstance(labels, list):
                labels = labels.split(",")
            labels = [label.strip() for label in labels]
        except Exception as e:
            # Deliberate best-effort: a malformed completion is reported and
            # treated as "no labels" so the default label is used below.
            print(completion)
            print(f"Could not extract the label from the completion: {str(e)}")
            labels = []

        labels = [label for label in labels if label in self.classes_]
        if not labels:
            labels = self._get_default_label()
        # Truncate to exactly `max_labels` entries (the previous slice kept one
        # too few). A plain string is left alone so that it is never sliced
        # character-wise.
        # NOTE(review): whether _get_default_label() can return a str is not
        # visible here - confirm against the base class.
        if (
            labels is not None
            and not isinstance(labels, str)
            and len(labels) > self.max_labels
        ):
            labels = labels[: self.max_labels]
        return labels

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: List[List[str]],
    ):
        """Fits the model to the given data.

        Both inputs are converted with ``self._to_np``, passed to the base
        class ``fit``, and then to ``self._tune``.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training data
        y : List[List[str]]
            training labels

        Returns
        -------
        MultiLabelGPTClassifier
            self
        """
        X = self._to_np(X)
        y = self._to_np(y)
        super().fit(X, y)
        self._tune(X, y)
        return self
| 246 | + |
| 247 | + |
114 | 248 | # similarly to PaLM, this is not a classifier, but a quick way to re-use the code |
115 | 249 | # the hierarchy of classes will be reworked in the next releases |
116 | 250 | class GPT(_BaseZeroShotGPTClassifier, _Tunable): |
|
0 commit comments