Skip to content

Commit fc455c4

Browse files
feat(whisper): Track connections individually not by lookup (#227)
* feat(whisper): Track connections in a list, not keyed by meeting_id * updated logic to ensure connection is handled properly * update vox connection manager to match new connection-only scheme * rename disconnect, remove send * cleanup naming * proper close * fix awaits and disconnect * handle exceptions on disconnections
1 parent eaf54d9 commit fc455c4

File tree

5 files changed

+76
-48
lines changed

5 files changed

+76
-48
lines changed
Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
1+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException
22

33
from skynet.logs import get_logger
44
from skynet.modules.stt.streaming_whisper.connection_manager import ConnectionManager
@@ -12,20 +12,25 @@
1212

1313
@app.websocket('/ws/{meeting_id}')
1414
async def websocket_endpoint(websocket: WebSocket, meeting_id: str, auth_token: str | None = None):
15-
await ws_connection_manager.connect(websocket, meeting_id, auth_token)
16-
try:
17-
while True:
15+
connection = await ws_connection_manager.connect(websocket, meeting_id, auth_token)
16+
if connection:
17+
while connection.connected:
1818
try:
1919
chunk = await websocket.receive_bytes()
20+
except WebSocketDisconnect as dc:
21+
log.info(f'Meeting {connection.meeting_id} has ended')
22+
await ws_connection_manager.disconnect(connection, already_closed=True)
23+
break
24+
except WebSocketException as wserr:
25+
log.warning(f'Error on websocket {connection.meeting_id}. Error {wserr.__class__}: \n{wserr}')
26+
await ws_connection_manager.disconnect(connection)
27+
break
2028
except Exception as err:
21-
log.warning(f'Expected bytes, received something else, disconnecting {meeting_id}. Error: \n{err}')
22-
ws_connection_manager.disconnect(meeting_id)
29+
log.warning(f'Expected bytes, received something else, disconnecting {connection.meeting_id}. Error {err.__class__}: \n{err}')
30+
await ws_connection_manager.disconnect(connection)
2331
break
2432
if len(chunk) == 1 and ord(b'' + chunk) == 0:
25-
log.info(f'Received disconnect message for {meeting_id}')
26-
ws_connection_manager.disconnect(meeting_id)
33+
log.info(f'Received disconnect message for {connection.meeting_id}')
34+
await ws_connection_manager.disconnect(connection)
2735
break
28-
await ws_connection_manager.process(meeting_id, chunk, utils.now())
29-
except WebSocketDisconnect:
30-
ws_connection_manager.disconnect(meeting_id)
31-
log.info(f'Meeting {meeting_id} has ended')
36+
await ws_connection_manager.process(connection, chunk, utils.now())

skynet/modules/stt/streaming_whisper/connection_manager.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515

1616
class ConnectionManager:
17-
connections: dict[str, MeetingConnection]
17+
connections: list[MeetingConnection]
1818
flush_audio_task: Task | None
1919

2020
def __init__(self):
21-
self.connections: dict[str, MeetingConnection] = {}
21+
self.connections: list[MeetingConnection] = []
2222
self.flush_audio_task = None
2323

2424
async def connect(self, websocket: WebSocket, meeting_id: str, auth_token: str | None):
@@ -29,37 +29,48 @@ async def connect(self, websocket: WebSocket, meeting_id: str, auth_token: str |
2929
await websocket.close(401, 'Bad JWT token')
3030
return
3131
await websocket.accept()
32-
self.connections[meeting_id] = MeetingConnection(websocket)
32+
connection = MeetingConnection(websocket, meeting_id)
33+
self.connections.append(connection)
3334
if self.flush_audio_task is None:
3435
loop = asyncio.get_running_loop()
3536
self.flush_audio_task = loop.create_task(self.flush_working_audio_worker())
3637
inc_ws_conn_count()
3738
log.info(f'Meeting with id {meeting_id} started. Ongoing meetings {len(self.connections)}')
3839

39-
async def process(self, meeting_id: str, chunk: bytes, chunk_timestamp: int):
40-
log.debug(f'Processing chunk for meeting {meeting_id}')
41-
if meeting_id not in self.connections:
42-
log.warning(f'No such meeting id {meeting_id}, the connection was probably closed.')
43-
return
44-
results = await self.connections[meeting_id].process(chunk, chunk_timestamp)
45-
await self.send(meeting_id, results)
40+
return connection
4641

47-
async def send(self, meeting_id: str, results: list[utils.TranscriptionResponse] | None):
42+
async def process(self, connection: MeetingConnection, chunk: bytes, chunk_timestamp: int):
43+
log.debug(f'Processing chunk for meeting {connection.meeting_id}')
44+
45+
try:
46+
results = await connection.process(chunk, chunk_timestamp)
47+
await self.send(connection, results)
48+
except Exception as e:
49+
log.error(f'Error processing chunk for meeting {connection.meeting_id}: {e}')
50+
await self.disconnect(connection)
51+
52+
async def send(self, connection: MeetingConnection, results: list[utils.TranscriptionResponse] | None):
4853
if results is not None:
4954
for result in results:
5055
try:
51-
await self.connections[meeting_id].ws.send_json(result.model_dump())
56+
await connection.ws.send_json(result.model_dump())
5257
except WebSocketDisconnect as e:
53-
log.warning(f'Meeting {meeting_id}: the connection was closed before sending all results: {e}')
54-
self.disconnect(meeting_id)
58+
log.warning(f'Meeting {connection.meeting_id}: the connection was closed before sending all results: {e}')
59+
await self.disconnect(connection, True)
60+
break
5561
except Exception as ex:
56-
log.error(f'Meeting {meeting_id}: exception while sending transcription results {ex}')
62+
log.error(f'Meeting {connection.meeting_id}: exception while sending transcription results {ex}')
5763

58-
def disconnect(self, meeting_id: str):
64+
async def disconnect(self, connection: MeetingConnection, already_closed = False):
5965
try:
60-
del self.connections[meeting_id]
61-
except KeyError:
62-
log.warning(f'The meeting {meeting_id} doesn\'t exist anymore.')
66+
self.connections.remove(connection)
67+
except ValueError:
68+
log.warning(f'The connection for meeting {connection.meeting_id} doesn\'t exist in the list anymore.')
69+
if not already_closed:
70+
await connection.close()
71+
else:
72+
# mark connection as disconnected
73+
connection.disconnect()
6374
dec_ws_conn_count()
6475

6576
async def flush_working_audio_worker(self):
@@ -69,15 +80,15 @@ async def flush_working_audio_worker(self):
6980
to the next utterance when the participant resumes speaking.
7081
"""
7182
while True:
72-
for meeting_id in self.connections:
73-
for participant in self.connections[meeting_id].participants:
74-
state = self.connections[meeting_id].participants[participant]
83+
for connection in self.connections:
84+
for participant in connection.participants:
85+
state = connection.participants[participant]
7586
diff = utils.now() - state.last_received_chunk
7687
log.debug(
77-
f'Participant {participant} in meeting {meeting_id} has been silent for {diff} ms and has {len(state.working_audio)} bytes of audio'
88+
f'Participant {participant} in meeting {connection.meeting_id} has been silent for {diff} ms and has {len(state.working_audio)} bytes of audio'
7889
)
7990
if diff > whisper_flush_interval and len(state.working_audio) > 0 and not state.is_transcribing:
80-
log.info(f'Forcing a transcription in meeting {meeting_id} for {participant}')
81-
results = await self.connections[meeting_id].force_transcription(participant)
82-
await self.send(meeting_id, results)
91+
log.info(f'Forcing a transcription in meeting {connection.meeting_id} for {participant}')
92+
results = await connection.force_transcription(participant)
93+
await self.send(connection, results)
8394
await asyncio.sleep(1)

skynet/modules/stt/streaming_whisper/meeting_connection.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,19 @@ class MeetingConnection:
2222
previous_transcription_store: List[List[int]]
2323
tokenizer: Tokenizer | None
2424
meeting_language: str | None
25+
meeting_id: str
26+
ws: WebSocket
27+
connected: True
2528

26-
def __init__(self, ws: WebSocket):
29+
def __init__(self, ws: WebSocket, meeting_id: str):
2730
self.participants = {}
2831
self.ws = ws
32+
self.meeting_id = meeting_id
2933
self.previous_transcription_tokens = []
3034
self.previous_transcription_store = []
3135
self.meeting_language = None
3236
self.tokenizer = None
37+
self.connected = True
3338

3439
async def update_initial_prompt(self, previous_payloads: list[utils.TranscriptionResponse]):
3540
for payload in previous_payloads:
@@ -68,3 +73,10 @@ async def force_transcription(self, participant_id: str):
6873
await self.update_initial_prompt(payloads)
6974
return payloads
7075
return None
76+
77+
def disconnect(self):
78+
self.connected = False
79+
80+
async def close(self):
81+
await self.ws.close()
82+
self.disconnect()

skynet/modules/stt/vox/app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async def websocket_endpoint(websocket: WebSocket, auth_token: str | None = None
2323
decoder = PcmaDecoder()
2424
resampler = PcmResampler()
2525
session_id = utils.Uuid7().get()
26-
await ws_connection_manager.connect(websocket, session_id, auth_token)
26+
connection = await ws_connection_manager.connect(websocket, session_id, auth_token)
2727

2828
data_map = dict()
2929
resampler = None
@@ -64,7 +64,7 @@ async def websocket_endpoint(websocket: WebSocket, auth_token: str | None = None
6464

6565
task = asyncio.create_task(
6666
ws_connection_manager.process(
67-
session_id, participant['header'] + decoded_raw, media['timestamp']
67+
connection, participant['header'] + decoded_raw, media['timestamp']
6868
)
6969
)
7070

@@ -75,7 +75,7 @@ async def websocket_endpoint(websocket: WebSocket, auth_token: str | None = None
7575
participant['raw'] = b''
7676

7777
except WebSocketDisconnect:
78-
ws_connection_manager.disconnect(session_id)
78+
ws_connection_manager.disconnect(connection, True)
7979
data_map.clear()
8080
log.info(f'Session {session_id} has ended')
8181
break
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
from fastapi import WebSocketDisconnect
22

33
from skynet.logs import get_logger
4-
from skynet.modules.stt.streaming_whisper.connection_manager import ConnectionManager as BaseConnectionManager
4+
from skynet.modules.stt.streaming_whisper.connection_manager import ConnectionManager as BaseConnectionManager, MeetingConnection
55
from skynet.modules.stt.streaming_whisper.utils.utils import TranscriptionResponse
66

77
log = get_logger(__name__)
88

99

1010
class ConnectionManager(BaseConnectionManager):
11-
async def send(self, session_id: str, results: list[TranscriptionResponse] | None):
11+
async def send(self, connection: MeetingConnection, results: list[TranscriptionResponse] | None):
1212
if results is None:
1313
return
1414

1515
final_results = [r for r in results if r.type == 'final']
1616
for result in final_results:
1717
try:
18-
await self.connections[session_id].ws.send_json(
18+
await connection.ws.send_json(
1919
{
2020
'timestamp': result.ts,
2121
'tag': result.participant_id,
@@ -25,7 +25,7 @@ async def send(self, session_id: str, results: list[TranscriptionResponse] | Non
2525
)
2626
log.debug(f'Participant {result.participant_id} result: {result.text}')
2727
except WebSocketDisconnect as e:
28-
log.warning(f'Session {session_id}: the connection was closed before sending all results: {e}')
29-
self.disconnect(session_id)
28+
log.warning(f'Session {connection.meeting_id}: the connection was closed before sending all results: {e}')
29+
self.disconnect(connection, True)
3030
except Exception as ex:
31-
log.error(f'Session {session_id}: exception while sending transcription results {ex}')
31+
log.error(f'Session {connection.meeting_id}: exception while sending transcription results {ex}')

0 commit comments

Comments
 (0)