|
13 | 13 | extract_json_key, |
14 | 14 | ) |
15 | 15 | from skllm.config import SKLLMConfig as _Config |
| 16 | +from skllm.utils import to_numpy as _to_numpy |
| 17 | +from skllm.openai.mixin import OpenAIMixin as _OAIMixin |
16 | 18 |
|
17 | | - |
18 | | -class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin): |
| 19 | +class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin): |
19 | 20 | def __init__( |
20 | 21 | self, |
21 | 22 | openai_key: Optional[str] = None, |
22 | 23 | openai_org: Optional[str] = None, |
23 | 24 | openai_model: str = "gpt-3.5-turbo", |
24 | 25 | ): |
25 | | - self.openai_key = openai_key |
26 | | - self.openai_org = openai_org |
| 26 | + self._set_keys(openai_key, openai_org) |
27 | 27 | self.openai_model = openai_model |
28 | 28 |
|
| 29 | + def _to_np(self, X): |
| 30 | + return _to_numpy(X) |
| 31 | + |
29 | 32 | def fit( |
30 | 33 | self, |
31 | 34 | X: Optional[Union[np.ndarray, pd.Series, List[str]]], |
32 | 35 | y: Union[np.ndarray, pd.Series, List[str], List[List[str]]], |
33 | 36 | ): |
34 | | - if isinstance(X, np.ndarray): |
35 | | - X = np.squeeze(X) |
| 37 | + X = self._to_np(X) |
36 | 38 | self.classes_, self.probabilities_ = self._get_unique_targets(y) |
37 | 39 | return self |
38 | 40 |
|
39 | 41 | def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): |
40 | | - if isinstance(X, np.ndarray): |
41 | | - X = np.squeeze(X) |
| 42 | + X = self._to_np(X) |
42 | 43 | predictions = [] |
43 | 44 | for i in tqdm(range(len(X))): |
44 | 45 | predictions.append(self._predict_single(X[i])) |
45 | 46 | return predictions |
46 | 47 |
|
47 | | - def _get_openai_key(self): |
48 | | - key = self.openai_key |
49 | | - if key is None: |
50 | | - key = _Config.get_openai_key() |
51 | | - if key is None: |
52 | | - raise RuntimeError("OpenAI key was not found") |
53 | | - return key |
54 | | - |
55 | | - def _get_openai_org(self): |
56 | | - key = self.openai_org |
57 | | - if key is None: |
58 | | - key = _Config.get_openai_org() |
59 | | - if key is None: |
60 | | - raise RuntimeError("OpenAI organization was not found") |
61 | | - return key |
62 | | - |
63 | 48 | @abstractmethod |
64 | 49 | def _extract_labels(self, y: Any) -> List[str]: |
65 | 50 | pass |
@@ -126,8 +111,7 @@ def fit( |
126 | 111 | X: Optional[Union[np.ndarray, pd.Series, List[str]]], |
127 | 112 | y: Union[np.ndarray, pd.Series, List[str]], |
128 | 113 | ): |
129 | | - if isinstance(y, np.ndarray): |
130 | | - y = np.squeeze(y) |
| 114 | + y = self._to_np(y) |
131 | 115 | return super().fit(X, y) |
132 | 116 |
|
133 | 117 |
|
|
0 commit comments