Skip to content

Commit 210d05e

Browse files
OKUA1iryna-kondr
andcommitted
vertex zero shot
Co-authored-by: Iryna Kondrashchenko <[email protected]>
1 parent 0837b4a commit 210d05e

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

skllm/llm/vertex/mixin.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Optional, Union, List, Any, Dict, Mapping
2+
from skllm.config import SKLLMConfig as _Config
3+
from skllm.llm.base import (
4+
BaseClassifierMixin,
5+
BaseEmbeddingMixin,
6+
BaseTextCompletionMixin,
7+
BaseTunableMixin,
8+
)
9+
from skllm.llm.vertex.completion import get_completion_chat_mode, get_completion
10+
from skllm.utils import extract_json_key
11+
import numpy as np
12+
from tqdm import tqdm
13+
import json
14+
15+
16+
class VertexMixin:
17+
pass
18+
19+
20+
class VertexTextCompletionMixin(BaseTextCompletionMixin):
21+
def _get_chat_completion(
22+
self,
23+
model: str,
24+
messages: Union[str, List[Dict[str, str]]],
25+
system_message: Optional[str],
26+
examples: Optional[List] = None,
27+
) -> str:
28+
if examples is not None:
29+
raise NotImplementedError(
30+
"Examples API is not yet supported for Vertex AI."
31+
)
32+
if not isinstance(messages, str):
33+
raise ValueError("Only messages as strings are supported.")
34+
if model.startswith("chat-"):
35+
completion = get_completion_chat_mode(model, system_message, messages)
36+
else:
37+
completion = get_completion(model, messages)
38+
return str(completion)
39+
40+
41+
class VertexClassifierMixin(BaseClassifierMixin, VertexTextCompletionMixin):
42+
def _extract_out_label(self, completion: str, **kwargs) -> Any:
43+
"""Extracts the label from a completion.
44+
45+
Parameters
46+
----------
47+
label : Mapping[str, Any]
48+
The label to extract.
49+
50+
Returns
51+
-------
52+
label : str
53+
"""
54+
print(completion)
55+
try:
56+
label = extract_json_key(str(completion), "label")
57+
except Exception as e:
58+
print(completion)
59+
print(f"Could not extract the label from the completion: {str(e)}")
60+
label = ""
61+
print(label)
62+
return label
63+
64+
65+
class VertexEmbeddingMixin(BaseEmbeddingMixin):
66+
def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:
67+
raise NotImplementedError("Embeddings are not yet supported for Vertex AI.")
68+
69+
70+
class VertexTunableMixin(BaseTunableMixin):
71+
# TODO
72+
def _tune(self, X: Any, y: Any):
73+
raise NotImplementedError("Tuning is not yet supported for Vertex AI.")

skllm/models/_base/classifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from skllm.memory.base import IndexConstructor
3232
from skllm.memory._sklearn_nn import SklearnMemoryIndex
3333
from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer
34+
import ast
3435

3536
_TRAINING_SAMPLE_PROMPT_TEMPLATE = """
3637
Sample input:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from skllm.llm.vertex.mixin import VertexClassifierMixin as _VertexClassifierMixin
2+
from skllm.models._base.classifier import (
3+
BaseZeroShotClassifier as _BaseZeroShotClassifier,
4+
SingleLabelMixin as _SingleLabelMixin,
5+
MultiLabelMixin as _MultiLabelMixin,
6+
)
7+
from typing import Optional
8+
9+
10+
class ZeroShotVertexClassifier(
11+
_BaseZeroShotClassifier, _SingleLabelMixin, _VertexClassifierMixin
12+
):
13+
def __init__(
14+
self,
15+
model: str = "text-bison@001",
16+
default_label: Optional[str] = "Random",
17+
prompt_template: Optional[str] = None,
18+
**kwargs,
19+
):
20+
super().__init__(
21+
model=model,
22+
default_label=default_label,
23+
prompt_template=prompt_template,
24+
**kwargs,
25+
)
26+
27+
28+
class MultiLabelZeroShotVertexClassifier(
29+
_BaseZeroShotClassifier, _MultiLabelMixin, _VertexClassifierMixin
30+
):
31+
def __init__(
32+
self,
33+
model: str = "text-bison@001",
34+
default_label: Optional[str] = "Random",
35+
prompt_template: Optional[str] = None,
36+
max_labels: Optional[int] = 5,
37+
**kwargs,
38+
):
39+
super().__init__(
40+
model=model,
41+
default_label=default_label,
42+
prompt_template=prompt_template,
43+
max_labels=max_labels,
44+
**kwargs,
45+
)

0 commit comments

Comments
 (0)