Skip to content

Commit f993cab

Browse files
Fixed import to ensure that tests run
1 parent fd51a6e commit f993cab

File tree

4 files changed

+20
-16
lines changed

4 files changed

+20
-16
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,6 @@ test.py
162162
tmp.ipynb
163163
tmp.py
164164
*.pickle
165+
166+
# vscode
167+
.vscode/

tests/test_chatgpt.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
import unittest
22
from unittest.mock import patch
33

4-
from skllm.openai.chatgpt import (
5-
construct_message,
6-
extract_json_key,
7-
get_chat_completion,
8-
)
4+
from skllm.openai.chatgpt import construct_message, get_chat_completion
5+
from skllm.utils import extract_json_key
96

107

118
class TestChatGPT(unittest.TestCase):
12-
139
@patch("skllm.openai.credentials.set_credentials")
1410
@patch("openai.ChatCompletion.create")
1511
def test_get_chat_completion(self, mock_create, mock_set_credentials):
@@ -21,9 +17,18 @@ def test_get_chat_completion(self, mock_create, mock_set_credentials):
2117

2218
result = get_chat_completion(messages, key, org, model)
2319

24-
self.assertTrue(mock_set_credentials.call_count <= 1, "set_credentials should be called at most once")
25-
self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception "
26-
"on the first call")
20+
self.assertTrue(
21+
mock_set_credentials.call_count <= 1,
22+
"set_credentials should be called at most once",
23+
)
24+
self.assertEqual(
25+
mock_create.call_count,
26+
2,
27+
(
28+
"ChatCompletion.create should be called twice due to an exception "
29+
"on the first call"
30+
),
31+
)
2732
self.assertEqual(result, "success")
2833

2934
def test_construct_message(self):
@@ -45,5 +50,5 @@ def test_extract_json_key(self):
4550
self.assertEqual(result_with_invalid_key, None)
4651

4752

48-
if __name__ == '__main__':
53+
if __name__ == "__main__":
4954
unittest.main()

tests/test_gpt_few_shot_clf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
import numpy as np
77

8-
from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
9-
from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
8+
from skllm import DynamicFewShotGPTClassifier, FewShotGPTClassifier
109

1110

1211
class TestFewShotGPTClassifier(unittest.TestCase):

tests/test_gpt_zero_shot_clf.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55

66
import numpy as np
77

8-
from skllm.models.gpt_zero_shot_clf import (
9-
MultiLabelZeroShotGPTClassifier,
10-
ZeroShotGPTClassifier,
11-
)
8+
from skllm import MultiLabelZeroShotGPTClassifier, ZeroShotGPTClassifier
129

1310

1411
def _get_ret(label):

0 commit comments

Comments
 (0)