
Commit 5491ec8

Merge pull request #118 from 00x808080/feature/anthropic-api
Add Anthropic API support
2 parents 35315db + 827cf94 commit 5491ec8

12 files changed, +644 -0 lines changed

skllm/config.py

Lines changed: 23 additions & 0 deletions
@@ -7,6 +7,7 @@
 _AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
 _GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
 _GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"
+_ANTHROPIC_KEY_VAR = "SKLLM_CONFIG_ANTHROPIC_KEY"
 _GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
 _GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
 _GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"

@@ -168,6 +169,28 @@ def get_gpt_url() -> Optional[str]:
             GPT URL.
         """
         return os.environ.get(_GPT_URL_VAR, None)
+
+    @staticmethod
+    def set_anthropic_key(key: str) -> None:
+        """Sets the Anthropic key.
+
+        Parameters
+        ----------
+        key : str
+            Anthropic key.
+        """
+        os.environ[_ANTHROPIC_KEY_VAR] = key
+
+    @staticmethod
+    def get_anthropic_key() -> Optional[str]:
+        """Gets the Anthropic key.
+
+        Returns
+        -------
+        Optional[str]
+            Anthropic key.
+        """
+        return os.environ.get(_ANTHROPIC_KEY_VAR, None)
 
     @staticmethod
     def reset_gpt_url():
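
For reference, a minimal usage sketch of the new configuration helpers (assuming they are exposed as static methods on SKLLMConfig, which is how skllm/llm/anthropic/mixin.py imports them below):

from skllm.config import SKLLMConfig

# Stores the key in the SKLLM_CONFIG_ANTHROPIC_KEY environment variable.
SKLLMConfig.set_anthropic_key("<your-anthropic-api-key>")

# Estimators constructed without an explicit key fall back to this value.
assert SKLLMConfig.get_anthropic_key() == "<your-anthropic-api-key>"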

skllm/llm/anthropic/completion.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+from typing import Dict, List, Optional
+from skllm.llm.anthropic.credentials import set_credentials
+from skllm.utils import retry
+
+@retry(max_retries=3)
+def get_chat_completion(
+    messages: List[Dict],
+    key: str,
+    model: str = "claude-3-haiku-20240307",
+    max_tokens: int = 1000,
+    temperature: float = 0.0,
+    system: Optional[str] = None,
+    json_response: bool = False,
+) -> dict:
+    """
+    Gets a chat completion from the Anthropic Claude API using the Messages API.
+
+    Parameters
+    ----------
+    messages : List[Dict]
+        Input messages to use.
+    key : str
+        The Anthropic API key to use.
+    model : str, optional
+        The Claude model to use.
+    max_tokens : int, optional
+        Maximum tokens to generate.
+    temperature : float, optional
+        Sampling temperature.
+    system : str, optional
+        System message to set the assistant's behavior.
+    json_response : bool, optional
+        Whether to request a JSON-formatted response. Defaults to False.
+
+    Returns
+    -------
+    response : dict
+        The completion response from the API.
+    """
+    if not messages:
+        raise ValueError("Messages list cannot be empty")
+    if not isinstance(messages, list):
+        raise TypeError("Messages must be a list")
+
+    client = set_credentials(key)
+
+    if json_response and system:
+        system = f"{system.rstrip('.')}. Respond in JSON format."
+    elif json_response:
+        system = "Respond in JSON format."
+
+    formatted_messages = [
+        {
+            "role": "user",  # explicitly set role to "user"
+            "content": [
+                {
+                    "type": "text",
+                    "text": message.get("content", "")
+                }
+            ]
+        }
+        for message in messages
+    ]
+
+    response = client.messages.create(
+        model=model,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        system=system,
+        messages=formatted_messages,
+    )
+    return response
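
A minimal calling sketch for the helper above (hypothetical key and prompt; the returned object is an Anthropic Messages API response whose text lives in response.content[0].text, as the mixin later in this commit assumes):

from skllm.llm.anthropic.completion import get_chat_completion

response = get_chat_completion(
    messages=[{"content": "Classify the sentiment of: 'The battery life is great.'"}],
    key="<your-anthropic-api-key>",  # hypothetical placeholder
    model="claude-3-haiku-20240307",
    system="You are a sentiment classifier. Reply with a single label.",
    json_response=True,
)
print(response.content[0].text)  # e.g. '{"label": "positive"}'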

skllm/llm/anthropic/credentials.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from anthropic import Anthropic
+
+
+def set_credentials(key: str) -> Anthropic:
+    """Set the Anthropic key and return a client.
+
+    Parameters
+    ----------
+    key : str
+        The Anthropic key to use.
+    """
+    client = Anthropic(api_key=key)
+    return client

skllm/llm/anthropic/mixin.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+from typing import Optional, Union, Any, List, Dict, Mapping
+from skllm.config import SKLLMConfig as _Config
+from skllm.llm.anthropic.completion import get_chat_completion
+from skllm.utils import extract_json_key
+from skllm.llm.base import BaseTextCompletionMixin, BaseClassifierMixin
+import json
+
+
+class ClaudeMixin:
+    """A mixin class that provides the Claude API key to other classes."""
+
+    _prefer_json_output = False
+
+    def _set_keys(self, key: Optional[str] = None) -> None:
+        """Set the Claude API key."""
+        self.key = key
+
+    def _get_claude_key(self) -> str:
+        """Get the Claude key from the class or the config file."""
+        key = self.key
+        if key is None:
+            key = _Config.get_anthropic_key()
+        if key is None:
+            raise RuntimeError("Claude API key was not found")
+        return key
+
+class ClaudeTextCompletionMixin(ClaudeMixin, BaseTextCompletionMixin):
+    """A mixin class that provides text completion capabilities using the Claude API."""
+
+    def _get_chat_completion(
+        self,
+        model: str,
+        messages: Union[str, List[Dict[str, str]]],
+        system_message: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        """Gets a chat completion from the Anthropic API.
+
+        Parameters
+        ----------
+        model : str
+            The model to use.
+        messages : Union[str, List[Dict[str, str]]]
+            Input messages to use.
+        system_message : Optional[str]
+            A system message to use.
+        **kwargs : Any
+            Placeholder.
+
+        Returns
+        -------
+        completion : dict
+        """
+        if isinstance(messages, str):
+            messages = [{"content": messages}]
+        elif isinstance(messages, list):
+            messages = [{"content": msg["content"]} for msg in messages]
+
+        completion = get_chat_completion(
+            messages=messages,
+            key=self._get_claude_key(),
+            model=model,
+            system=system_message,
+            json_response=self._prefer_json_output,
+            **kwargs,
+        )
+        return completion
+
+    def _convert_completion_to_str(self, completion: Mapping[str, Any]):
+        """Converts a Claude API completion to a string."""
+        try:
+            if hasattr(completion, 'content'):
+                return completion.content[0].text
+            return completion.get('content', [{}])[0].get('text', '')
+        except Exception as e:
+            print(f"Error converting completion to string: {str(e)}")
+            return ""
+
+class ClaudeClassifierMixin(ClaudeTextCompletionMixin, BaseClassifierMixin):
+    """A mixin class that provides classification capabilities using the Claude API."""
+
+    _prefer_json_output = True
+
+    def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str:
+        """Extracts the label from a Claude API completion."""
+        try:
+            content = self._convert_completion_to_str(completion)
+            if not self._prefer_json_output:
+                return content.strip()
+
+            # Attempt to parse the content as JSON and extract the label
+            try:
+                data = json.loads(content)
+                if "label" in data:
+                    return data["label"]
+            except json.JSONDecodeError:
+                pass
+            return ""
+
+        except Exception as e:
+            print(f"Error extracting label: {str(e)}")
+            return ""
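
As a sketch of the output contract these mixins rely on (not part of the commit): with _prefer_json_output enabled, the classifier expects the model's reply to be a JSON object containing a "label" key, and anything else yields an empty string.

import json

def parse_label(content: str) -> str:
    # Mirrors the JSON-parsing branch of ClaudeClassifierMixin._extract_out_label.
    try:
        data = json.loads(content)
        if "label" in data:
            return data["label"]
    except json.JSONDecodeError:
        pass
    return ""

print(parse_label('{"label": "positive"}'))  # -> "positive"
print(parse_label("positive"))               # -> "" (not JSON)
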
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+from skllm.models._base.classifier import (
+    BaseFewShotClassifier,
+    BaseDynamicFewShotClassifier,
+    SingleLabelMixin,
+    MultiLabelMixin,
+)
+from skllm.llm.anthropic.mixin import ClaudeClassifierMixin
+from skllm.models.gpt.vectorization import GPTVectorizer
+from skllm.models._base.vectorizer import BaseVectorizer
+from skllm.memory.base import IndexConstructor
+from typing import Optional
+
+
+class FewShotClaudeClassifier(BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin):
+    """Few-shot text classifier using Anthropic's Claude API for single-label classification tasks."""
+
+    def __init__(
+        self,
+        model: str = "claude-3-haiku-20240307",
+        default_label: str = "Random",
+        prompt_template: Optional[str] = None,
+        key: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Few-shot text classifier using Anthropic's Claude API.
+
+        Parameters
+        ----------
+        model : str, optional
+            model to use, by default "claude-3-haiku-20240307"
+        default_label : str, optional
+            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
+        prompt_template : Optional[str], optional
+            custom prompt template to use, by default None
+        key : Optional[str], optional
+            estimator-specific API key; if None, retrieved from the global config
+        """
+        super().__init__(
+            model=model,
+            default_label=default_label,
+            prompt_template=prompt_template,
+            **kwargs,
+        )
+        self._set_keys(key)
+
+
+class MultiLabelFewShotClaudeClassifier(
+    BaseFewShotClassifier, ClaudeClassifierMixin, MultiLabelMixin
+):
+    """Few-shot text classifier using Anthropic's Claude API for multi-label classification tasks."""
+
+    def __init__(
+        self,
+        model: str = "claude-3-haiku-20240307",
+        default_label: str = "Random",
+        max_labels: Optional[int] = 5,
+        prompt_template: Optional[str] = None,
+        key: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Multi-label few-shot text classifier using Anthropic's Claude API.
+
+        Parameters
+        ----------
+        model : str, optional
+            model to use, by default "claude-3-haiku-20240307"
+        default_label : str, optional
+            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
+        max_labels : Optional[int], optional
+            maximum labels per sample, by default 5
+        prompt_template : Optional[str], optional
+            custom prompt template to use, by default None
+        key : Optional[str], optional
+            estimator-specific API key; if None, retrieved from the global config
+        """
+        super().__init__(
+            model=model,
+            default_label=default_label,
+            max_labels=max_labels,
+            prompt_template=prompt_template,
+            **kwargs,
+        )
+        self._set_keys(key)
+
+
+class DynamicFewShotClaudeClassifier(
+    BaseDynamicFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin
+):
+    """
+    Dynamic few-shot text classifier using Anthropic's Claude API for
+    single-label classification tasks with dynamic example selection using GPT embeddings.
+    """
+
+    def __init__(
+        self,
+        model: str = "claude-3-haiku-20240307",
+        default_label: str = "Random",
+        prompt_template: Optional[str] = None,
+        key: Optional[str] = None,
+        n_examples: int = 3,
+        memory_index: Optional[IndexConstructor] = None,
+        vectorizer: Optional[BaseVectorizer] = None,
+        metric: Optional[str] = "euclidean",
+        **kwargs,
+    ):
+        """
+        Dynamic few-shot text classifier using Anthropic's Claude API.
+        For each sample, N closest examples are retrieved from the memory.
+
+        Parameters
+        ----------
+        model : str, optional
+            model to use, by default "claude-3-haiku-20240307"
+        default_label : str, optional
+            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
+        prompt_template : Optional[str], optional
+            custom prompt template to use, by default None
+        key : Optional[str], optional
+            estimator-specific API key; if None, retrieved from the global config
+        n_examples : int, optional
+            number of closest examples per class to be retrieved, by default 3
+        memory_index : Optional[IndexConstructor], optional
+            custom memory index, for details check `skllm.memory` submodule
+        vectorizer : Optional[BaseVectorizer], optional
+            scikit-llm vectorizer; if None, `GPTVectorizer` is used
+        metric : Optional[str], optional
+            metric used for similarity search, by default "euclidean"
+        """
+        if vectorizer is None:
+            vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key)
+        super().__init__(
+            model=model,
+            default_label=default_label,
+            prompt_template=prompt_template,
+            n_examples=n_examples,
+            memory_index=memory_index,
+            vectorizer=vectorizer,
+            metric=metric,
+        )
+        self._set_keys(key)
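
For orientation, a minimal end-to-end sketch of the new estimators (toy data; fit and predict come from the scikit-learn-style base classes, and the import path below is an assumption since the file name is not visible in this view):

from skllm.config import SKLLMConfig
from skllm.models.anthropic.classification import FewShotClaudeClassifier  # assumed module path

SKLLMConfig.set_anthropic_key("<your-anthropic-api-key>")

# Hypothetical training data.
X = ["I love this phone", "Terrible battery", "Decent camera for the price"]
y = ["positive", "negative", "neutral"]

clf = FewShotClaudeClassifier(model="claude-3-haiku-20240307")
clf.fit(X, y)
print(clf.predict(["Screen broke after a week"]))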
