11import openai
22from skllm .config import SKLLMConfig as _Config
33from time import sleep
4-
5-
6- def get_chat_completion (
7- messages : dict ,
8- key : str ,
9- org : str ,
10- model : str = "gpt-3.5-turbo" ,
11- max_retries : int = 3 ,
12- api = "openai" ,
13- ):
14- """Gets a chat completion from the OpenAI API.
15-
16- Parameters
17- ----------
18- messages : dict
19- input messages to use.
20- key : str
21- The OPEN AI key to use.
22- org : str
23- The OPEN AI organization ID to use.
24- model : str, optional
25- The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
26- max_retries : int, optional
27- The maximum number of retries to use. Defaults to 3.
28- api : str
29- The API to use. Must be one of "openai" or "azure". Defaults to "openai".
30-
31- Returns
32- -------
33- completion : dict
34- """
35- if api == "openai" :
36- set_credentials (key , org )
37- model_dict = {"model" : model }
38- elif api == "azure" :
39- set_azure_credentials (key , org )
40- model_dict = {"engine" : model }
41- else :
42- raise ValueError ("Invalid API" )
43- error_msg = None
44- error_type = None
45- for _ in range (max_retries ):
46- try :
47- completion = openai .ChatCompletion .create (
48- temperature = 0.0 , messages = messages , ** model_dict
49- )
50- return completion
51- except Exception as e :
52- error_msg = str (e )
53- error_type = type (e ).__name__
54- sleep (3 )
55- print (
56- f"Could not obtain the completion after { max_retries } retries: `{ error_type } ::"
57- f" { error_msg } `"
58- )
59-
4+ from openai import OpenAI , AzureOpenAI
605
616def set_credentials (key : str , org : str ) -> None :
627 """Set the OpenAI key and organization.
@@ -68,12 +13,8 @@ def set_credentials(key: str, org: str) -> None:
6813 org : str
6914 The OPEN AI organization ID to use.
7015 """
71- openai .api_key = key
72- openai .organization = org
73- openai .api_type = "open_ai"
74- openai .api_version = None
75- openai .api_base = "https://api.openai.com/v1"
76-
16+ client = OpenAI (api_key = key , organization = org )
17+ return client
7718
7819def set_azure_credentials (key : str , org : str ) -> None :
7920 """Sets OpenAI credentials for Azure.
@@ -85,9 +26,6 @@ def set_azure_credentials(key: str, org: str) -> None:
8526 org : str
8627 The OpenAI (Azure) organization ID to use.
8728 """
88- if not openai .api_type or not openai .api_type .startswith ("azure" ):
89- openai .api_type = "azure"
90- openai .api_key = key
91- openai .organization = org
92- openai .api_base = _Config .get_azure_api_base ()
93- openai .api_version = _Config .get_azure_api_version ()
29+ client = AzureOpenAI (api_key = key , organization = org , api_version = _Config .get_azure_api_version (), azure_endpoint = _Config .get_azure_api_base ())
30+ return client
31+
0 commit comments