Skip to content

Commit 5fbecdc

Browse files
committed
openai v1.2
1 parent 210d05e commit 5fbecdc

File tree

6 files changed

+28
-82
lines changed

6 files changed

+28
-82
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ build-backend = "setuptools.build_meta"
66
dependencies = [
77
"scikit-learn>=1.1.0",
88
"pandas>=1.5.0",
9-
"openai>=0.27.9",
9+
"openai>=1.2.0",
1010
"tqdm>=4.60.0",
1111
"google-cloud-aiplatform>=1.27.0"
1212
]
1313
name = "scikit-llm"
14-
version = "0.4.1"
14+
version = "1.0.0a1"
1515
authors = [
1616
{ name="Oleg Kostromin", email="[email protected]" },
1717
{ name="Iryna Kondrashchenko", email="[email protected]" },

skllm/llm/gpt/clients/openai/completion.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import openai
2+
from openai import OpenAI
23
from skllm.llm.gpt.clients.openai.credentials import (
34
set_azure_credentials,
45
set_credentials,
@@ -36,15 +37,14 @@ def get_chat_completion(
3637
completion : dict
3738
"""
3839
if api == "openai":
39-
set_credentials(key, org)
40-
model_dict = {"model": model}
40+
client = set_credentials(key, org)
41+
model_dict = {"model": model, "response_format": {"type": "json_object"}}
4142
elif api == "azure":
42-
set_azure_credentials(key, org)
43-
model_dict = {"engine": model}
43+
client = set_azure_credentials(key, org)
44+
model_dict = {"model": model}
4445
else:
4546
raise ValueError("Invalid API")
46-
47-
completion = openai.ChatCompletion.create(
47+
completion = client.chat.completions.create(
4848
temperature=0.0, messages=messages, **model_dict
4949
)
5050
return completion
Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,7 @@
11
import openai
22
from skllm.config import SKLLMConfig as _Config
33
from time import sleep
4-
5-
6-
def get_chat_completion(
7-
messages: dict,
8-
key: str,
9-
org: str,
10-
model: str = "gpt-3.5-turbo",
11-
max_retries: int = 3,
12-
api="openai",
13-
):
14-
"""Gets a chat completion from the OpenAI API.
15-
16-
Parameters
17-
----------
18-
messages : dict
19-
input messages to use.
20-
key : str
21-
The OPEN AI key to use.
22-
org : str
23-
The OPEN AI organization ID to use.
24-
model : str, optional
25-
The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
26-
max_retries : int, optional
27-
The maximum number of retries to use. Defaults to 3.
28-
api : str
29-
The API to use. Must be one of "openai" or "azure". Defaults to "openai".
30-
31-
Returns
32-
-------
33-
completion : dict
34-
"""
35-
if api == "openai":
36-
set_credentials(key, org)
37-
model_dict = {"model": model}
38-
elif api == "azure":
39-
set_azure_credentials(key, org)
40-
model_dict = {"engine": model}
41-
else:
42-
raise ValueError("Invalid API")
43-
error_msg = None
44-
error_type = None
45-
for _ in range(max_retries):
46-
try:
47-
completion = openai.ChatCompletion.create(
48-
temperature=0.0, messages=messages, **model_dict
49-
)
50-
return completion
51-
except Exception as e:
52-
error_msg = str(e)
53-
error_type = type(e).__name__
54-
sleep(3)
55-
print(
56-
f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
57-
f" {error_msg}`"
58-
)
59-
4+
from openai import OpenAI, AzureOpenAI
605

616
def set_credentials(key: str, org: str) -> None:
627
"""Set the OpenAI key and organization.
@@ -68,12 +13,8 @@ def set_credentials(key: str, org: str) -> None:
6813
org : str
6914
The OPEN AI organization ID to use.
7015
"""
71-
openai.api_key = key
72-
openai.organization = org
73-
openai.api_type = "open_ai"
74-
openai.api_version = None
75-
openai.api_base = "https://api.openai.com/v1"
76-
16+
client = OpenAI(api_key=key, organization=org)
17+
return client
7718

7819
def set_azure_credentials(key: str, org: str) -> None:
7920
"""Sets OpenAI credentials for Azure.
@@ -85,9 +26,6 @@ def set_azure_credentials(key: str, org: str) -> None:
8526
org : str
8627
The OpenAI (Azure) organization ID to use.
8728
"""
88-
if not openai.api_type or not openai.api_type.startswith("azure"):
89-
openai.api_type = "azure"
90-
openai.api_key = key
91-
openai.organization = org
92-
openai.api_base = _Config.get_azure_api_base()
93-
openai.api_version = _Config.get_azure_api_version()
29+
client = AzureOpenAI(api_key=key, organization=org, api_version=_Config.get_azure_api_version(), azure_endpoint = _Config.get_azure_api_base())
30+
return client
31+

skllm/llm/gpt/clients/openai/embedding.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from skllm.llm.gpt.clients.openai.credentials import set_credentials
22
from skllm.utils import retry
33
import openai
4+
from openai import OpenAI
45

56

67
@retry(max_retries=3)
@@ -31,10 +32,11 @@ def get_embedding(
3132
emb : list
3233
The GPT embedding for the string.
3334
"""
34-
set_credentials(key, org)
35+
client = OpenAI()
36+
set_credentials(client, key, org)
3537
text = [str(t).replace("\n", " ") for t in text]
3638
embeddings = []
37-
emb = openai.Embedding.create(input=text, model=model)
39+
emb = client.embeddings.create(input=text, model=model)
3840
for i in range(len(emb["data"])):
3941
e = emb["data"][i]["embedding"]
4042
if not isinstance(e, list):

skllm/llm/gpt/mixin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,12 @@ def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> Any:
151151
label : str
152152
"""
153153
try:
154-
label = extract_json_key(
155-
completion["choices"][0]["message"]["content"], "label"
156-
)
154+
if hasattr(completion, "__getitem__"):
155+
label = extract_json_key(
156+
completion["choices"][0]["message"]["content"], "label"
157+
)
158+
else:
159+
label = extract_json_key(completion.choices[0].message.content, "label")
157160
except Exception as e:
158161
print(completion)
159162
print(f"Could not extract the label from the completion: {str(e)}")

skllm/models/_base/classifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def _extract_labels(self, y) -> List[str]:
135135

136136

137137
class BaseClassifier(ABC, _SklBaseEstimator, _SklClassifierMixin):
138+
139+
system_msg = "You are a text classifier."
140+
138141
def __init__(
139142
self,
140143
model: Optional[str], # model can initially be None for tunable estimators
@@ -263,7 +266,7 @@ def _get_prompt(self, x: str) -> dict:
263266
self.max_labels,
264267
template=self._get_prompt_template(),
265268
)
266-
return {"messages": prompt, "system_message": "You are a text classifier."}
269+
return {"messages": prompt, "system_message": self.system_msg}
267270

268271

269272
class BaseFewShotClassifier(BaseClassifier):

0 commit comments

Comments
 (0)