|
1 | | -pass |
| 1 | +from typing import Any, Union, List, Optional |
| 2 | +from abc import abstractmethod, ABC |
| 3 | +from numpy import ndarray |
| 4 | +from tqdm import tqdm |
| 5 | +import numpy as np |
| 6 | +import pandas as pd |
| 7 | +from skllm.utils import to_numpy as _to_numpy |
| 8 | +from sklearn.base import ( |
| 9 | + BaseEstimator as _SklBaseEstimator, |
| 10 | + TransformerMixin as _SklTransformerMixin, |
| 11 | +) |
| 12 | +from skllm.llm.base import BaseTunableMixin as _BaseTunableMixin |
| 13 | + |
| 14 | + |
class BaseText2TextModel(ABC, _SklBaseEstimator, _SklTransformerMixin):
    """Abstract scikit-learn style transformer producing one output per input.

    Subclasses must implement ``_get_prompt``. ``_predict_single`` also relies
    on ``self.model`` and ``self._get_chat_completion``, neither of which is
    defined on this class; they are expected to be supplied by the concrete
    class or an LLM mixin.
    """

    def fit(self, X: Any, y: Any = None):
        """Do nothing and return the estimator itself.

        Parameters
        ----------
        X : Any
            Ignored.
        y : Any, optional
            Ignored. Optional so that the estimator can be fitted without
            targets, as is conventional for scikit-learn transformers.

        Returns
        -------
        BaseText2TextModel
            ``self``.
        """
        return self

    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """Alias for ``transform``.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data.

        Returns
        -------
        The value returned by ``transform(X)``.
        """
        return self.transform(X)

    def fit_transform(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
    ) -> List[Any]:
        """Fit on ``(X, y)`` and then transform ``X``.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data.
        y : Union[np.ndarray, pd.Series, List[str]], optional
            Targets, forwarded unchanged to ``fit``.

        Returns
        -------
        The value returned by ``transform(X)`` on the fitted estimator.
        """
        return self.fit(X, y).transform(X)

    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]) -> List[Any]:
        """Generates an output for each input.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data; converted with ``to_numpy`` before iteration.

        Returns
        -------
        list
            One ``_predict_single`` result per element of ``X``, in input
            order.
        """
        X = _to_numpy(X)
        # tqdm only wraps the iteration to display a progress bar.
        return [self._predict_single(x) for x in tqdm(X)]

    def _predict_single(self, x: Any) -> Any:
        """Build the prompt for ``x`` and return the chat completion for it.

        The mapping returned by ``_get_prompt`` is unpacked into keyword
        arguments of ``_get_chat_completion``.
        """
        prompt_dict = self._get_prompt(x)
        # `_get_chat_completion` will be inherited from the LLM
        return self._get_chat_completion(model=self.model, **prompt_dict)

    @abstractmethod
    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        pass
| 57 | + |
| 58 | + |
class BaseTunableText2TextModel(BaseText2TextModel):
    """Text-to-text model that is tuned on ``(X, y)`` pairs in ``fit``.

    ``_tune`` is not defined on this class; ``fit`` requires the instance to
    also be a ``skllm.llm.base.BaseTunableMixin`` before calling it.
    """

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Union[np.ndarray, pd.Series, List[str]],
    ):
        """Tune the model on the given data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            Inputs, forwarded unchanged to ``_tune``.
        y : Union[np.ndarray, pd.Series, List[str]]
            Targets, forwarded unchanged to ``_tune``.

        Returns
        -------
        BaseTunableText2TextModel
            ``self``.

        Raises
        ------
        TypeError
            If the instance is not a ``BaseTunableMixin``.
        """
        if not isinstance(self, _BaseTunableMixin):
            raise TypeError(
                "Classifier must be mixed with a skllm.llm.base.BaseTunableMixin class"
            )
        self._tune(X, y)
        return self

    def _get_prompt(self, x: str) -> str:
        """Returns the prompt to use for a single input (``x`` as a string)."""
        # NOTE(review): the base `_predict_single` unpacks the result of
        # `_get_prompt` with `**`, which requires a mapping, while this
        # returns a plain string. Presumably a concrete subclass or mixin
        # overrides one of the two; confirm before relying on this default.
        return str(x)

    def _predict_single(self, x: str) -> str:
        """Return the prediction for ``x``, refusing to run on an untuned model.

        Raises
        ------
        RuntimeError
            If ``self.model`` is ``None`` or has never been set.
        """
        # getattr: an instance that never assigned `model` should get the
        # same explicit error rather than an AttributeError.
        if getattr(self, "model", None) is None:
            raise RuntimeError("Model has not been tuned yet")
        return super()._predict_single(x)
0 commit comments