Skip to content

Commit 61dd410

Browse files
authored
fix: add oci blackout when rate limited (#228)
* fix: add oci blackout when rate limited * code review
1 parent fc455c4 commit 61dd410

File tree

7 files changed

+121
-14
lines changed

7 files changed

+121
-14
lines changed

skynet/env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def tobool(val: str | None):
111111

112112
# jobs
113113
job_timeout = int(os.environ.get('JOB_TIMEOUT', 60 * 5)) # 5 minutes default
114-
max_concurrency = int(os.environ.get('MAX_CONCURRENCY', 10))
114+
max_concurrency = int(os.environ.get('MAX_CONCURRENCY', 5))
115115

116116
# summaries
117117
summary_minimum_payload_length = int(os.environ.get('SUMMARY_MINIMUM_PAYLOAD_LENGTH', 100))
@@ -137,6 +137,7 @@ def tobool(val: str | None):
137137
oci_compartment_id = os.environ.get('OCI_COMPARTMENT_ID')
138138
oci_auth_type = os.environ.get('OCI_AUTH_TYPE', 'API_KEY')
139139
oci_config_profile = os.environ.get('OCI_CONFIG_PROFILE', 'DEFAULT')
140+
oci_blackout_fallback_duration = int(os.environ.get('OCI_BLACKOUT_FALLBACK_DURATION', 30))
140141
oci_available = oci_model_id and oci_service_endpoint and oci_compartment_id and oci_auth_type and oci_config_profile
141142
use_oci = oci_available and llama_path.startswith('oci://')
142143

skynet/modules/stt/streaming_whisper/app.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@ async def websocket_endpoint(websocket: WebSocket, meeting_id: str, auth_token:
1717
while connection.connected:
1818
try:
1919
chunk = await websocket.receive_bytes()
20-
except WebSocketDisconnect as dc:
20+
except WebSocketDisconnect:
2121
log.info(f'Meeting {connection.meeting_id} has ended')
2222
await ws_connection_manager.disconnect(connection, already_closed=True)
23-
break
23+
break
2424
except WebSocketException as wserr:
2525
log.warning(f'Error on websocket {connection.meeting_id}. Error {wserr.__class__}: \n{wserr}')
2626
await ws_connection_manager.disconnect(connection)
2727
break
2828
except Exception as err:
29-
log.warning(f'Expected bytes, received something else, disconnecting {connection.meeting_id}. Error {err.__class__}: \n{err}')
29+
log.warning(
30+
f'Expected bytes, received something else, disconnecting {connection.meeting_id}. Error {err.__class__}: \n{err}'
31+
)
3032
await ws_connection_manager.disconnect(connection)
3133
break
3234
if len(chunk) == 1 and ord(b'' + chunk) == 0:

skynet/modules/stt/streaming_whisper/connection_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def connect(self, websocket: WebSocket, meeting_id: str, auth_token: str |
4141

4242
async def process(self, connection: MeetingConnection, chunk: bytes, chunk_timestamp: int):
4343
log.debug(f'Processing chunk for meeting {connection.meeting_id}')
44-
44+
4545
try:
4646
results = await connection.process(chunk, chunk_timestamp)
4747
await self.send(connection, results)
@@ -55,13 +55,15 @@ async def send(self, connection: MeetingConnection, results: list[utils.Transcri
5555
try:
5656
await connection.ws.send_json(result.model_dump())
5757
except WebSocketDisconnect as e:
58-
log.warning(f'Meeting {connection.meeting_id}: the connection was closed before sending all results: {e}')
58+
log.warning(
59+
f'Meeting {connection.meeting_id}: the connection was closed before sending all results: {e}'
60+
)
5961
await self.disconnect(connection, True)
6062
break
6163
except Exception as ex:
6264
log.error(f'Meeting {connection.meeting_id}: exception while sending transcription results {ex}')
6365

64-
async def disconnect(self, connection: MeetingConnection, already_closed = False):
66+
async def disconnect(self, connection: MeetingConnection, already_closed=False):
6567
try:
6668
self.connections.remove(connection)
6769
except ValueError:

skynet/modules/stt/vox/connection_manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from fastapi import WebSocketDisconnect
22

33
from skynet.logs import get_logger
4-
from skynet.modules.stt.streaming_whisper.connection_manager import ConnectionManager as BaseConnectionManager, MeetingConnection
4+
from skynet.modules.stt.streaming_whisper.connection_manager import (
5+
ConnectionManager as BaseConnectionManager,
6+
MeetingConnection,
7+
)
58
from skynet.modules.stt.streaming_whisper.utils.utils import TranscriptionResponse
69

710
log = get_logger(__name__)
@@ -25,7 +28,9 @@ async def send(self, connection: MeetingConnection, results: list[TranscriptionR
2528
)
2629
log.debug(f'Participant {result.participant_id} result: {result.text}')
2730
except WebSocketDisconnect as e:
28-
log.warning(f'Session {connection.meeting_id}: the connection was closed before sending all results: {e}')
31+
log.warning(
32+
f'Session {connection.meeting_id}: the connection was closed before sending all results: {e}'
33+
)
2934
self.disconnect(connection, True)
3035
except Exception as ex:
3136
log.error(f'Session {connection.meeting_id}: exception while sending transcription results {ex}')

skynet/modules/ttt/llm_selector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def override_job_processor(job_id: str, processor: Processors) -> None:
3333
overriden_processors[job_id] = processor
3434

3535
@staticmethod
36-
def get_job_processor(customer_id: str, job_id: Optional[str] = None) -> Processors:
36+
def get_job_processor(customer_id: str, job_id: Optional[str] = None, oci_blackout: bool = False) -> Processors:
3737
if job_id and job_id in overriden_processors:
3838
return overriden_processors[job_id]
3939

@@ -50,7 +50,7 @@ def get_job_processor(customer_id: str, job_id: Optional[str] = None) -> Process
5050
if api_type == CredentialsType.LOCAL.value:
5151
return Processors.LOCAL
5252

53-
if oci_available:
53+
if oci_available and not oci_blackout:
5454
return Processors.OCI
5555

5656
return Processors.LOCAL
@@ -62,8 +62,9 @@ def select(
6262
max_completion_tokens: Optional[int] = None,
6363
temperature: Optional[float] = 0,
6464
stream: Optional[bool] = False,
65+
oci_blackout: bool = False,
6566
) -> BaseChatModel:
66-
processor = LLMSelector.get_job_processor(customer_id, job_id)
67+
processor = LLMSelector.get_job_processor(customer_id, job_id, oci_blackout)
6768
options = get_credentials(customer_id)
6869

6970
if processor == Processors.OPENAI:

skynet/modules/ttt/processor.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from datetime import datetime, timedelta, timezone
23
from operator import itemgetter
34
from typing import List, Optional
45

@@ -10,11 +11,13 @@
1011
from langchain_core.documents import Document
1112
from langchain_core.language_models.chat_models import BaseChatModel
1213
from langchain_core.output_parsers import StrOutputParser
14+
15+
from oci.exceptions import TransientServiceError
1316
from openai.types.chat import ChatCompletionMessageParam
1417

1518
from skynet.constants import response_prefix
1619

17-
from skynet.env import llama_n_ctx, modules, use_oci
20+
from skynet.env import llama_n_ctx, modules, oci_blackout_fallback_duration, use_oci
1821
from skynet.logs import get_logger
1922
from skynet.modules.monitoring import MAP_REDUCE_CHUNKING_COUNTER
2023
from skynet.modules.ttt.assistant.constants import assistant_rag_question_extractor
@@ -44,6 +47,31 @@
4447

4548
log = get_logger(__name__)
4649

50+
# Global OCI blackout state management
51+
_oci_blackout_until: Optional[datetime] = None
52+
53+
54+
def set_oci_blackout(duration_seconds: int) -> None:
55+
"""Set OCI blackout for the specified duration."""
56+
global _oci_blackout_until
57+
_oci_blackout_until = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
58+
log.warning(f"OCI blackout set until {_oci_blackout_until} ({duration_seconds} seconds)")
59+
60+
61+
def is_oci_blackout_active() -> bool:
62+
"""Check if OCI is currently in blackout period."""
63+
global _oci_blackout_until
64+
if _oci_blackout_until is None:
65+
return False
66+
67+
now = datetime.now(timezone.utc)
68+
if now >= _oci_blackout_until:
69+
_oci_blackout_until = None # Clear expired blackout
70+
log.info("OCI blackout period expired, resuming normal processing")
71+
return False
72+
73+
return True
74+
4775

4876
hint_type_to_prompt = {
4977
JobType.SUMMARY: {
@@ -201,7 +229,12 @@ async def process(job: Job) -> str:
201229
job_type = job.type
202230
customer_id = job.metadata.customer_id
203231

204-
llm = LLMSelector.select(customer_id, job_id=job.id, **{'max_completion_tokens': payload.max_completion_tokens})
232+
llm = LLMSelector.select(
233+
customer_id,
234+
job_id=job.id,
235+
oci_blackout=is_oci_blackout_active(),
236+
**{'max_completion_tokens': payload.max_completion_tokens},
237+
)
205238

206239
try:
207240
if job_type == JobType.ASSIST:
@@ -212,6 +245,18 @@ async def process(job: Job) -> str:
212245
result = await process_text(llm, payload)
213246
else:
214247
raise ValueError(f'Invalid job type {job_type}')
248+
except TransientServiceError as e:
249+
log.warning(f"Job {job.id} hit TransientServiceError: {e}")
250+
251+
# Set blackout using fallback duration
252+
blackout_duration = oci_blackout_fallback_duration
253+
log.info(f"TransientServiceError detected, setting {blackout_duration}s blackout")
254+
set_oci_blackout(blackout_duration)
255+
256+
# Switch current job to local processing
257+
LLMSelector.override_job_processor(job.id, Processors.LOCAL)
258+
return await process(job)
259+
215260
except Exception as e:
216261
log.warning(f"Job {job.id} failed: {e}")
217262

skynet/modules/ttt/processor_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22

3+
from oci.exceptions import TransientServiceError
4+
35
from skynet.modules.ttt.summaries.v1.models import DocumentMetadata, DocumentPayload, Job, JobType, Processors
46

57

@@ -249,3 +251,52 @@ async def test_process_with_oci_error_fallback(self, process_fixture):
249251

250252
assert LLMSelector.get_job_processor(job.metadata.customer_id, job.id) == Processors.LOCAL
251253
assert LLMSelector.select.call_count == 2
254+
255+
@pytest.mark.asyncio
256+
async def test_process_with_transient_service_error_blackout(self, process_fixture):
257+
'''Test that TransientServiceError triggers blackout and subsequent jobs use LOCAL processor.'''
258+
259+
from skynet.modules.ttt.llm_selector import LLMSelector
260+
from skynet.modules.ttt.processor import process
261+
262+
# Create TransientServiceError with circuit breaker message
263+
circuit_breaker_msg = (
264+
'Circuit "test-id" OPEN until 2025-09-04 13:23:43.823175+00:00 (12 failures, 17 sec remaining)'
265+
)
266+
transient_error = TransientServiceError(status=429, code='429', headers={}, message=circuit_breaker_msg)
267+
268+
process_fixture.patch(
269+
'skynet.modules.ttt.llm_selector.get_credentials',
270+
return_value={'type': 'OCI'},
271+
)
272+
process_fixture.patch('skynet.modules.ttt.llm_selector.oci_available', True)
273+
process_fixture.patch('skynet.modules.ttt.processor.use_oci', False) # allow fallback
274+
process_fixture.patch('skynet.modules.ttt.processor.summarize', side_effect=[transient_error, None, None])
275+
276+
job1 = Job(
277+
payload=DocumentPayload(text="First job"),
278+
metadata=DocumentMetadata(customer_id='test'),
279+
type=JobType.SUMMARY,
280+
)
281+
282+
job2 = Job(
283+
payload=DocumentPayload(text="Second job"),
284+
metadata=DocumentMetadata(customer_id='test'),
285+
type=JobType.SUMMARY,
286+
)
287+
288+
# First job should trigger TransientServiceError and set blackout
289+
await process(job1)
290+
assert LLMSelector.get_job_processor(job1.metadata.customer_id, job1.id) == Processors.LOCAL
291+
292+
# Second job should immediately go to LOCAL due to active blackout
293+
await process(job2)
294+
# Check that blackout causes LOCAL processor selection
295+
from skynet.modules.ttt.processor import is_oci_blackout_active
296+
297+
blackout_active = is_oci_blackout_active()
298+
assert blackout_active == True, "Blackout should be active after TransientServiceError"
299+
assert (
300+
LLMSelector.get_job_processor(job2.metadata.customer_id, job2.id, oci_blackout=blackout_active)
301+
== Processors.LOCAL
302+
)

0 commit comments

Comments
 (0)