1- import json
21from typing import AsyncIterator , cast
32
43from ragna .core import PackageRequirement , RagnaException , Requirement , Source
54
6- from ._api import ApiAssistant
5+ from ._http_api import HttpApiAssistant
76
87
9- class AnthropicApiAssistant ( ApiAssistant ):
8+ class AnthropicAssistant ( HttpApiAssistant ):
109 _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
1110 _MODEL : str
1211
@@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
3635 + "</documents>"
3736 )
3837
39- async def _call_api (
40- self , prompt : str , sources : list [Source ], * , max_new_tokens : int
38+ async def answer (
39+ self , prompt : str , sources : list [Source ], * , max_new_tokens : int = 256
4140 ) -> AsyncIterator [str ]:
42- import httpx_sse
43-
4441 # See https://docs.anthropic.com/claude/reference/messages_post
4542 # See https://docs.anthropic.com/claude/reference/streaming
46- async with httpx_sse .aconnect_sse (
47- self ._client ,
43+ async for data in self ._stream_sse (
4844 "POST" ,
4945 "https://api.anthropic.com/v1/messages" ,
5046 headers = {
@@ -61,23 +57,19 @@ async def _call_api(
6157 "temperature" : 0.0 ,
6258 "stream" : True ,
6359 },
64- ) as event_source :
65- await self ._assert_api_call_is_success (event_source .response )
66-
67- async for sse in event_source .aiter_sse ():
68- data = json .loads (sse .data )
69- # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
70- if "error" in data :
71- raise RagnaException (data ["error" ].pop ("message" ), ** data ["error" ])
72- elif data ["type" ] == "message_stop" :
73- break
74- elif data ["type" ] != "content_block_delta" :
75- continue
60+ ):
61+ # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
62+ if "error" in data :
63+ raise RagnaException (data ["error" ].pop ("message" ), ** data ["error" ])
64+ elif data ["type" ] == "message_stop" :
65+ break
66+ elif data ["type" ] != "content_block_delta" :
67+ continue
7668
77- yield cast (str , data ["delta" ].pop ("text" ))
69+ yield cast (str , data ["delta" ].pop ("text" ))
7870
7971
80- class ClaudeOpus (AnthropicApiAssistant ):
72+ class ClaudeOpus (AnthropicAssistant ):
8173 """[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview)
8274
8375 !!! info "Required environment variables"
@@ -92,7 +84,7 @@ class ClaudeOpus(AnthropicApiAssistant):
9284 _MODEL = "claude-3-opus-20240229"
9385
9486
95- class ClaudeSonnet (AnthropicApiAssistant ):
87+ class ClaudeSonnet (AnthropicAssistant ):
9688 """[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview)
9789
9890 !!! info "Required environment variables"
@@ -107,7 +99,7 @@ class ClaudeSonnet(AnthropicApiAssistant):
10799 _MODEL = "claude-3-sonnet-20240229"
108100
109101
110- class ClaudeHaiku (AnthropicApiAssistant ):
102+ class ClaudeHaiku (AnthropicAssistant ):
111103 """[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview)
112104
113105 !!! info "Required environment variables"
0 commit comments