Skip to content

Commit 8cc24b3

Browse files
OKUA1iryna-kondr
andcommitted
v1 initial structure
Co-authored-by: Iryna Kondrashchenko <[email protected]>
1 parent cc79413 commit 8cc24b3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1176
-1981
lines changed

skllm/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
1-
# ordering is important here to prevent circular imports
2-
from skllm.models.gpt.gpt_zero_shot_clf import (
3-
MultiLabelZeroShotGPTClassifier,
4-
ZeroShotGPTClassifier,
5-
)
6-
from skllm.models.gpt.gpt_few_shot_clf import FewShotGPTClassifier
7-
from skllm.models.gpt.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier

skllm/google/completions.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

skllm/llm/base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any
3+
4+
5+
class BaseTextCompletionMixin(ABC):
6+
@abstractmethod
7+
def _get_chat_completion(self, **kwargs):
8+
"""Gets a chat completion from the LLM"""
9+
pass
10+
11+
12+
class BaseClassifierMixin(BaseTextCompletionMixin):
13+
@abstractmethod
14+
def _extract_out_label(self, completion: Any, **kwargs):
15+
"""Extracts the label from a completion"""
16+
pass
17+
18+
19+
class BaseEmbeddingMixin(ABC):
20+
@abstractmethod
21+
def _get_embeddings(self, **kwargs):
22+
"""Gets embeddings from the LLM"""
23+
pass
24+
25+
26+
class BaseTunableMixin(ABC):
27+
@abstractmethod
28+
def _tune(self, X: Any, y: Any):
29+
pass
File renamed without changes.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import openai
2+
from skllm.llm.gpt.clients.openai.credentials import (
3+
set_azure_credentials,
4+
set_credentials,
5+
)
6+
from skllm.utils import retry
7+
8+
9+
@retry(max_retries=3)
10+
def get_chat_completion(
11+
messages: dict,
12+
key: str,
13+
org: str,
14+
model: str = "gpt-3.5-turbo",
15+
api="openai",
16+
):
17+
"""Gets a chat completion from the OpenAI API.
18+
19+
Parameters
20+
----------
21+
messages : dict
22+
input messages to use.
23+
key : str
24+
The OPEN AI key to use.
25+
org : str
26+
The OPEN AI organization ID to use.
27+
model : str, optional
28+
The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
29+
max_retries : int, optional
30+
The maximum number of retries to use. Defaults to 3.
31+
api : str
32+
The API to use. Must be one of "openai" or "azure". Defaults to "openai".
33+
34+
Returns
35+
-------
36+
completion : dict
37+
"""
38+
if api == "openai":
39+
set_credentials(key, org)
40+
model_dict = {"model": model}
41+
elif api == "azure":
42+
set_azure_credentials(key, org)
43+
model_dict = {"engine": model}
44+
else:
45+
raise ValueError("Invalid API")
46+
47+
completion = openai.ChatCompletion.create(
48+
temperature=0.0, messages=messages, **model_dict
49+
)
50+
return completion

skllm/openai/chatgpt.py renamed to skllm/llm/gpt/clients/openai/credentials.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,6 @@
1-
from time import sleep
2-
31
import openai
4-
5-
from skllm.openai.credentials import set_azure_credentials, set_credentials
6-
7-
8-
def construct_message(role: str, content: str) -> dict:
9-
"""Constructs a message for the OpenAI API.
10-
11-
Parameters
12-
----------
13-
role : str
14-
The role of the message. Must be one of "system", "user", or "assistant".
15-
content : str
16-
The content of the message.
17-
18-
Returns
19-
-------
20-
message : dict
21-
"""
22-
if role not in ("system", "user", "assistant"):
23-
raise ValueError("Invalid role")
24-
return {"role": role, "content": content}
2+
from skllm.config import SKLLMConfig as _Config
3+
from time import sleep
254

265

276
def get_chat_completion(
@@ -77,3 +56,38 @@ def get_chat_completion(
7756
f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
7857
f" {error_msg}`"
7958
)
59+
60+
61+
def set_credentials(key: str, org: str) -> None:
62+
"""Set the OpenAI key and organization.
63+
64+
Parameters
65+
----------
66+
key : str
67+
The OpenAI key to use.
68+
org : str
69+
The OPEN AI organization ID to use.
70+
"""
71+
openai.api_key = key
72+
openai.organization = org
73+
openai.api_type = "open_ai"
74+
openai.api_version = None
75+
openai.api_base = "https://api.openai.com/v1"
76+
77+
78+
def set_azure_credentials(key: str, org: str) -> None:
79+
"""Sets OpenAI credentials for Azure.
80+
81+
Parameters
82+
----------
83+
key : str
84+
The OpenAI (Azure) key to use.
85+
org : str
86+
The OpenAI (Azure) organization ID to use.
87+
"""
88+
if not openai.api_type or not openai.api_type.startswith("azure"):
89+
openai.api_type = "azure"
90+
openai.api_key = key
91+
openai.organization = org
92+
openai.api_base = _Config.get_azure_api_base()
93+
openai.api_version = _Config.get_azure_api_version()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from skllm.llm.gpt.clients.openai.credentials import set_credentials
2+
from skllm.utils import retry
3+
import openai
4+
5+
6+
@retry(max_retries=3)
7+
def get_embedding(
8+
text: str,
9+
key: str,
10+
org: str,
11+
model: str = "text-embedding-ada-002",
12+
):
13+
"""
14+
Encodes a string and return the embedding for a string.
15+
16+
Parameters
17+
----------
18+
text : str
19+
The string to encode.
20+
key : str
21+
The OPEN AI key to use.
22+
org : str
23+
The OPEN AI organization ID to use.
24+
model : str, optional
25+
The model to use. Defaults to "text-embedding-ada-002".
26+
max_retries : int, optional
27+
The maximum number of retries to use. Defaults to 3.
28+
29+
Returns
30+
-------
31+
emb : list
32+
The GPT embedding for the string.
33+
"""
34+
set_credentials(key, org)
35+
text = [str(t).replace("\n", " ") for t in text]
36+
embeddings = []
37+
emb = openai.Embedding.create(input=text, model=model)
38+
for i in range(len(emb["data"])):
39+
e = emb["data"][i]["embedding"]
40+
if not isinstance(e, list):
41+
raise ValueError(
42+
f"Encountered unknown embedding format. Expected list, got {type(emb)}"
43+
)
44+
embeddings.append(e)
45+
return embeddings
Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
1-
from skllm.gpt4all_client import get_chat_completion as _g4a_get_chat_completion
2-
from skllm.openai.chatgpt import get_chat_completion as _oai_get_chat_completion
1+
from skllm.llm.gpt.clients.openai.completion import (
2+
get_chat_completion as _oai_get_chat_completion,
3+
)
4+
from skllm.llm.gpt.clients.gpt4all.completion import (
5+
get_chat_completion as _g4a_get_chat_completion,
6+
)
37

48

59
def get_chat_completion(
610
messages: dict,
711
openai_key: str = None,
812
openai_org: str = None,
913
model: str = "gpt-3.5-turbo",
10-
max_retries: int = 3,
1114
):
12-
"""Gets a chat completion from the OpenAI API."""
15+
"""Gets a chat completion from the OpenAI compatible API."""
1316
if model.startswith("gpt4all::"):
1417
return _g4a_get_chat_completion(messages, model[9:])
1518
else:
1619
api = "azure" if model.startswith("azure::") else "openai"
1720
if api == "azure":
1821
model = model[7:]
1922
return _oai_get_chat_completion(
20-
messages, openai_key, openai_org, model, max_retries, api=api
23+
messages, openai_key, openai_org, model, api=api
2124
)

skllm/llm/gpt/embedding.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from skllm.llm.gpt.clients.openai.embedding import get_embedding as _oai_get_embedding
2+
3+
4+
def get_embedding(
5+
text: str,
6+
key: str,
7+
org: str,
8+
model: str = "text-embedding-ada-002",
9+
):
10+
"""
11+
Encodes a string and return the embedding for a string.
12+
13+
Parameters
14+
----------
15+
text : str
16+
The string to encode.
17+
key : str
18+
The OPEN AI key to use.
19+
org : str
20+
The OPEN AI organization ID to use.
21+
model : str, optional
22+
The model to use. Defaults to "text-embedding-ada-002".
23+
24+
Returns
25+
-------
26+
emb : list
27+
The GPT embedding for the string.
28+
"""
29+
if model.startswith("gpt4all::"):
30+
raise ValueError("GPT4All is not supported for embeddings")
31+
elif model.startswith("azure::"):
32+
raise ValueError("Azure is not supported for embeddings")
33+
return _oai_get_embedding(text, key, org, model)

0 commit comments

Comments
 (0)