
Commit d1a3f4e

Merge pull request #1766 from roboflow/fix/watchdog-modal-termination
Specify public stun servers
2 parents c617216 + c4e0bd3 commit d1a3f4e

5 files changed: +126 additions, −50 deletions


inference/core/env.py

Lines changed: 12 additions & 1 deletion
@@ -765,7 +765,18 @@
 WEBRTC_MODAL_MODELS_PRELOAD_API_KEY = os.getenv("WEBRTC_MODAL_MODELS_PRELOAD_API_KEY")
 WEBRTC_MODAL_PRELOAD_MODELS = os.getenv("WEBRTC_MODAL_PRELOAD_MODELS")
 WEBRTC_MODAL_PRELOAD_HF_IDS = os.getenv("WEBRTC_MODAL_PRELOAD_HF_IDS")
-WEBRTC_MODAL_MIN_RAM_MB = int(os.getenv("WEBRTC_MODAL_MIN_RAM_MB", "4096"))
+try:
+    WEBRTC_MODAL_MIN_CPU_CORES = int(os.getenv("WEBRTC_MODAL_MIN_CPU_CORES"))
+except (ValueError, TypeError):
+    WEBRTC_MODAL_MIN_CPU_CORES = None
+try:
+    WEBRTC_MODAL_MIN_RAM_MB = int(os.getenv("WEBRTC_MODAL_MIN_RAM_MB"))
+except (ValueError, TypeError):
+    WEBRTC_MODAL_MIN_RAM_MB = None
+WEBRTC_MODAL_PUBLIC_STUN_SERVERS = os.getenv(
+    "WEBRTC_MODAL_PUBLIC_STUN_SERVERS",
+    "stun:stun.l.google.com:19302,stun:stun1.l.google.com:19302,stun:stun2.l.google.com:19302,stun:stun3.l.google.com:19302,stun:stun4.l.google.com:19302",
+)
 
 HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED = str2bool(
     os.getenv("HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED", "True")

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 93 additions & 46 deletions
@@ -3,7 +3,7 @@
 import os
 import subprocess
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional
 
 from inference.core import logger
 from inference.core.env import (
@@ -31,6 +31,7 @@
     WEBRTC_MODAL_GCP_SECRET_NAME,
     WEBRTC_MODAL_IMAGE_NAME,
     WEBRTC_MODAL_IMAGE_TAG,
+    WEBRTC_MODAL_MIN_CPU_CORES,
     WEBRTC_MODAL_MIN_RAM_MB,
     WEBRTC_MODAL_MODELS_PRELOAD_API_KEY,
     WEBRTC_MODAL_PRELOAD_HF_IDS,
@@ -39,6 +40,7 @@
     WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
     WEBRTC_MODAL_RTSP_PLACEHOLDER,
     WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
+    WEBRTC_MODAL_SHUTDOWN_RESERVE,
     WEBRTC_MODAL_TOKEN_ID,
     WEBRTC_MODAL_TOKEN_SECRET,
     WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
@@ -126,6 +128,7 @@ def check_nvidia_smi_gpu() -> str:
     "buffer_containers": WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS,
     "scaledown_window": WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW,
     "memory": WEBRTC_MODAL_MIN_RAM_MB,
+    "cpu": WEBRTC_MODAL_MIN_CPU_CORES,
     "timeout": WEBRTC_MODAL_FUNCTION_TIME_LIMIT,
     "enable_memory_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT,
     "max_inputs": WEBRTC_MODAL_FUNCTION_MAX_INPUTS,
@@ -167,17 +170,62 @@ def check_nvidia_smi_gpu() -> str:
         "WEBRTC_MODAL_FUNCTION_TIME_LIMIT": str(WEBRTC_MODAL_FUNCTION_TIME_LIMIT),
         "WEBRTC_MODAL_IMAGE_NAME": WEBRTC_MODAL_IMAGE_NAME,
         "WEBRTC_MODAL_IMAGE_TAG": WEBRTC_MODAL_IMAGE_TAG,
+        "WEBRTC_MODAL_MIN_CPU_CORES": str(
+            WEBRTC_MODAL_MIN_CPU_CORES if WEBRTC_MODAL_MIN_CPU_CORES else ""
+        ),
+        "WEBRTC_MODAL_MIN_RAM_MB": str(
+            WEBRTC_MODAL_MIN_RAM_MB if WEBRTC_MODAL_MIN_RAM_MB else ""
+        ),
         "WEBRTC_MODAL_MODELS_PRELOAD_API_KEY": (
             str(WEBRTC_MODAL_MODELS_PRELOAD_API_KEY)
             if WEBRTC_MODAL_MODELS_PRELOAD_API_KEY
            else ""
         ),
         "WEBRTC_MODAL_RTSP_PLACEHOLDER": WEBRTC_MODAL_RTSP_PLACEHOLDER,
         "WEBRTC_MODAL_RTSP_PLACEHOLDER_URL": WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
+        "WEBRTC_MODAL_SHUTDOWN_RESERVE": str(WEBRTC_MODAL_SHUTDOWN_RESERVE),
     },
     "volumes": {MODEL_CACHE_DIR: rfcache_volume},
 }
 
+async def run_rtc_peer_connection_with_watchdog(
+    webrtc_request: WebRTCWorkerRequest,
+    send_answer: Callable[[WebRTCWorkerResult], None],
+    model_manager: ModelManager,
+):
+    from inference.core.interfaces.webrtc_worker.webrtc import (
+        init_rtc_peer_connection_with_loop,
+    )
+
+    watchdog = Watchdog(
+        timeout_seconds=30,
+    )
+
+    rtc_peer_connection_task = asyncio.create_task(
+        init_rtc_peer_connection_with_loop(
+            webrtc_request=webrtc_request,
+            send_answer=send_answer,
+            model_manager=model_manager,
+            heartbeat_callback=watchdog.heartbeat,
+        )
+    )
+
+    def on_timeout():
+        logger.info("Watchdog timeout reached")
+        rtc_peer_connection_task.cancel()
+
+    watchdog.on_timeout = on_timeout
+    watchdog.start()
+
+    try:
+        await rtc_peer_connection_task
+    except asyncio.CancelledError as exc:
+        logger.info("WebRTC connection task was cancelled (%s)", exc)
+    except Exception as exc:
+        logger.error(exc)
+    finally:
+        watchdog.stop()
+
 class RTCPeerConnectionModal:
     _model_manager: Optional[ModelManager] = modal.parameter(
         default=None, init=False
@@ -190,10 +238,6 @@ def rtc_peer_connection_modal(
         webrtc_request: WebRTCWorkerRequest,
         q: modal.Queue,
     ):
-        from inference.core.interfaces.webrtc_worker.webrtc import (
-            init_rtc_peer_connection_with_loop,
-        )
-
         logger.info("*** Spawning %s:", self.__class__.__name__)
         logger.info("Running on %s", self._gpu)
         logger.info("Inference tag: %s", docker_tag)
@@ -223,7 +267,6 @@ def rtc_peer_connection_modal(
         logger.info("rtsp_url: %s", webrtc_request.rtsp_url)
         logger.info("processing_timeout: %s", webrtc_request.processing_timeout)
         logger.info("requested_plan: %s", webrtc_request.requested_plan)
-        logger.info("requested_gpu: %s", webrtc_request.requested_gpu)
         logger.info("requested_region: %s", webrtc_request.requested_region)
         logger.info(
             "ICE servers: %s",
@@ -233,42 +276,28 @@ def rtc_peer_connection_modal(
                 else []
             ),
         )
+        logger.info(
+            "WEBRTC_MODAL_MIN_CPU_CORES: %s",
+            WEBRTC_MODAL_MIN_CPU_CORES or "not set",
+        )
+        logger.info(
+            "WEBRTC_MODAL_MIN_RAM_MB: %s", WEBRTC_MODAL_MIN_RAM_MB or "not set"
+        )
         logger.info("MODAL_CLOUD_PROVIDER: %s", MODAL_CLOUD_PROVIDER)
         logger.info("MODAL_IMAGE_ID: %s", MODAL_IMAGE_ID)
         logger.info("MODAL_REGION: %s", MODAL_REGION)
         logger.info("MODAL_TASK_ID: %s", MODAL_TASK_ID)
         logger.info("MODAL_ENVIRONMENT: %s", MODAL_ENVIRONMENT)
         logger.info("MODAL_IDENTITY_TOKEN: %s", MODAL_IDENTITY_TOKEN)
 
-        try:
-            current_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            current_loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(current_loop)
-
-        def on_timeout():
-            def shutdown():
-                for task in asyncio.all_tasks():
-                    task.cancel()
-                current_loop.stop()
-
-            current_loop.call_soon_threadsafe(shutdown)
-
-        watchdog = Watchdog(
-            timeout_seconds=30,
-            on_timeout=on_timeout,
-        )
-
         def send_answer(obj: WebRTCWorkerResult):
             logger.info("Sending webrtc answer")
             q.put(obj)
-        watchdog.start()
 
         if webrtc_request.processing_timeout == 0:
            error_msg = "Processing timeout is 0, skipping processing"
            logger.info(error_msg)
            send_answer(WebRTCWorkerResult(error_message=error_msg))
-            watchdog.stop()
            return
        if (
            not webrtc_request.webrtc_offer
@@ -278,20 +307,15 @@ def send_answer(obj: WebRTCWorkerResult):
            error_msg = "Webrtc offer is missing, skipping processing"
            logger.info(error_msg)
            send_answer(WebRTCWorkerResult(error_message=error_msg))
-            watchdog.stop()
            return
 
-        try:
-            asyncio.run(
-                init_rtc_peer_connection_with_loop(
-                    webrtc_request=webrtc_request,
-                    send_answer=send_answer,
-                    model_manager=self._model_manager,
-                    heartbeat_callback=watchdog.heartbeat,
-                )
+        asyncio.run(
+            run_rtc_peer_connection_with_watchdog(
+                webrtc_request=webrtc_request,
+                send_answer=send_answer,
+                model_manager=self._model_manager,
             )
-        except Exception as exc:
-            logger.error(exc)
+        )
 
         _exec_session_stopped = datetime.datetime.now()
         logger.info(
@@ -315,6 +339,8 @@ def send_answer(obj: WebRTCWorkerResult):
             video_source = "rtsp"
         elif not webrtc_request.webrtc_realtime_processing:
             video_source = "buffered browser stream"
+        else:
+            video_source = "realtime browser stream"
 
         usage_collector.record_usage(
             source=workflow_id,
@@ -330,7 +356,6 @@ def send_answer(obj: WebRTCWorkerResult):
             ).total_seconds(),
         )
         usage_collector.push_usage_payloads()
-        watchdog.stop()
         logger.info("Function completed")
 
     @modal.exit()
@@ -411,6 +436,9 @@ def start(self):
 def spawn_rtc_peer_connection_modal(
     webrtc_request: WebRTCWorkerRequest,
 ) -> WebRTCWorkerResult:
+    requested_gpu: Optional[str] = None
+    requested_ram_mb: Optional[int] = None
+    requested_cpu_cores: Optional[int] = None
     webrtc_plans: Optional[Dict[str, WebRTCPlan]] = (
         usage_collector._plan_details.get_webrtc_plans(
             api_key=webrtc_request.api_key
@@ -421,9 +449,11 @@ def spawn_rtc_peer_connection_modal(
             raise RoboflowAPIUnsuccessfulRequestError(
                 f"Unknown requested plan {webrtc_request.requested_plan}, available plans: {', '.join(webrtc_plans.keys())}"
             )
-        webrtc_request.requested_gpu = webrtc_plans[
-            webrtc_request.requested_plan
-        ].gpu
+        requested_gpu = webrtc_plans[webrtc_request.requested_plan].gpu
+        requested_ram_mb = webrtc_plans[webrtc_request.requested_plan].ram_mb
+        requested_cpu_cores = webrtc_plans[webrtc_request.requested_plan].cpu_cores
+
+    # TODO: requested_gpu is replaced with requested_plan
     if (
         webrtc_plans
         and not webrtc_request.requested_plan
@@ -435,6 +465,7 @@
             f"Requested gpu {webrtc_request.requested_gpu} not associated with any plan, available gpus: {', '.join(gpu_to_plan.keys())}"
         )
         webrtc_request.requested_plan = gpu_to_plan[webrtc_request.requested_gpu]
+        requested_gpu = webrtc_plans[webrtc_request.requested_plan].gpu
 
     # https://modal.com/docs/reference/modal.Client#from_credentials
     client = modal.Client.from_credentials(
@@ -483,7 +514,7 @@
     logger.info("Parametrized preload models: %s", WEBRTC_MODAL_PRELOAD_MODELS)
     preload_models = WEBRTC_MODAL_PRELOAD_MODELS
 
-    if webrtc_request.requested_gpu:
+    if requested_gpu:
         RTCPeerConnectionModal = RTCPeerConnectionModalGPU
     else:
         RTCPeerConnectionModal = RTCPeerConnectionModalCPU
@@ -505,16 +536,16 @@
     cls_with_options = deployed_cls.with_options(
         timeout=webrtc_request.processing_timeout,
     )
-    if webrtc_request.requested_gpu is not None:
+    if requested_gpu is not None:
         logger.info(
             "Spawning webrtc modal function with gpu %s",
-            webrtc_request.requested_gpu,
+            requested_gpu,
         )
         # Specify fallback GPU
         # TODO: with_options does not support gpu fallback
         # https://modal.com/docs/examples/gpu_fallbacks#set-fallback-gpus
         cls_with_options = cls_with_options.with_options(
-            gpu=webrtc_request.requested_gpu,
+            gpu=requested_gpu,
         )
     if webrtc_request.requested_region:
         logger.info(
@@ -524,6 +555,22 @@
         cls_with_options = cls_with_options.with_options(
             region=webrtc_request.requested_region,
         )
+    if requested_ram_mb is not None:
+        logger.info(
+            "Spawning webrtc modal function with ram %s",
+            requested_ram_mb,
+        )
+        cls_with_options = cls_with_options.with_options(
+            ram=requested_ram_mb,
+        )
+    if requested_cpu_cores is not None:
+        logger.info(
+            "Spawning webrtc modal function with cpu cores %s",
+            requested_cpu_cores,
+        )
+        cls_with_options = cls_with_options.with_options(
+            cpu=requested_cpu_cores,
+        )
     rtc_modal_obj: RTCPeerConnectionModal = cls_with_options(
         preload_hf_ids=preload_hf_ids,
         preload_models=preload_models,
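The new run_rtc_peer_connection_with_watchdog helper replaces the old loop-stopping shutdown with plain task cancellation: when heartbeats stop, the watchdog's on_timeout callback cancels the peer-connection task, and the try/except asyncio.CancelledError around the await turns that into a clean exit. A self-contained sketch of the same pattern, using stand-in classes rather than the repo's; note the sketch marshals the cancel onto the loop with call_soon_threadsafe, since the callback fires on the watchdog's own thread:

import asyncio
import threading
import time

class TinyWatchdog:
    # Illustrative stand-in for the repo's Watchdog, not the real class.
    def __init__(self, timeout_seconds: float):
        self.timeout_seconds = timeout_seconds
        self.on_timeout = None  # assigned after construction, as in the diff
        self._last_beat = time.monotonic()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def heartbeat(self):
        self._last_beat = time.monotonic()

    def start(self):
        if not self.on_timeout:
            raise ValueError("on_timeout must be set before start()")
        self._thread.start()

    def stop(self):
        self._stop.set()

    def _run(self):
        # Poll every 100 ms; fire once if heartbeats go stale.
        while not self._stop.wait(0.1):
            if time.monotonic() - self._last_beat > self.timeout_seconds:
                self.on_timeout()
                return

async def main():
    loop = asyncio.get_running_loop()
    watchdog = TinyWatchdog(timeout_seconds=0.5)
    # A worker that never heartbeats, so the watchdog fires.
    task = asyncio.create_task(asyncio.sleep(10))
    # Cancellation is scheduled onto the loop because the callback
    # runs on the watchdog thread, not the event loop thread.
    watchdog.on_timeout = lambda: loop.call_soon_threadsafe(task.cancel)
    watchdog.start()
    try:
        await task
    except asyncio.CancelledError:
        print("cancelled by watchdog")
    finally:
        watchdog.stop()

asyncio.run(main())

Compared with the removed on_timeout, which cancelled every task and stopped the whole event loop, cancelling the single task lets the code after asyncio.run() (usage recording, the "Function completed" log) still execute.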

inference/core/interfaces/webrtc_worker/watchdog.py

Lines changed: 9 additions & 3 deletions
@@ -1,23 +1,29 @@
 import datetime
 import threading
 import time
-from typing import Callable
+from typing import Callable, Optional
 
 from inference.core.logger import logger
 
 
 class Watchdog:
-    def __init__(self, timeout_seconds: int, on_timeout: Callable[[], None]):
+    def __init__(
+        self, timeout_seconds: int, on_timeout: Optional[Callable[[], None]] = None
+    ):
         self.timeout_seconds = timeout_seconds
         self.last_heartbeat = datetime.datetime.now()
-        self.on_timeout = on_timeout
+        self.on_timeout: Optional[Callable[[], None]] = on_timeout
         self._thread = threading.Thread(target=self._watchdog_thread)
         self._stopping = False
         self._last_log_ts = datetime.datetime.now()
         self._log_interval_seconds = 10
         self._heartbeats = 0
 
     def start(self):
+        if not self.on_timeout:
+            raise ValueError(
+                "on_timeout callback must be provided before starting the watchdog"
+            )
         self._thread.start()
 
     def stop(self):
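Since on_timeout is now optional at construction, the callback can be attached after the watchdog exists, which the new helper in modal.py relies on: its callback closes over a task that is only created after the watchdog. A hypothetical usage sketch of the guard added to start():

# Hypothetical usage of the Watchdog above.
from inference.core.interfaces.webrtc_worker.watchdog import Watchdog

watchdog = Watchdog(timeout_seconds=30)  # on_timeout deliberately deferred

try:
    watchdog.start()  # rejected: callback not attached yet
except ValueError as exc:
    print(exc)

watchdog.on_timeout = lambda: print("no heartbeat for 30s")
watchdog.start()      # accepted now that the callback is set
watchdog.heartbeat()  # refreshes last_heartbeat
watchdog.stop()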

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,7 @@
 
 from inference.core import logger
 from inference.core.env import (
+    WEBRTC_MODAL_PUBLIC_STUN_SERVERS,
     WEBRTC_MODAL_RTSP_PLACEHOLDER,
     WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
     WEBRTC_MODAL_SHUTDOWN_RESERVE,
@@ -814,6 +815,15 @@ async def init_rtc_peer_connection_with_loop(
                     credential=ice_server.credential,
                 )
             )
+        # Always add public stun servers (if specified)
+        if WEBRTC_MODAL_PUBLIC_STUN_SERVERS:
+            for stun_server in WEBRTC_MODAL_PUBLIC_STUN_SERVERS.split(","):
+                try:
+                    ice_servers.append(RTCIceServer(urls=stun_server.strip()))
+                except Exception as e:
+                    logger.warning(
+                        "Failed to add public stun server '%s': %s", stun_server, e
+                    )
     else:
         ice_servers = None
     peer_connection = RTCPeerConnectionWithLoop(
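The env default is a single comma-separated string, so the loop above fans it out into one ICE server per entry. A minimal sketch of that expansion, assuming aiortc's RTCIceServer; the default value below is copied from env.py:

from aiortc import RTCIceServer

WEBRTC_MODAL_PUBLIC_STUN_SERVERS = (
    "stun:stun.l.google.com:19302,stun:stun1.l.google.com:19302,"
    "stun:stun2.l.google.com:19302,stun:stun3.l.google.com:19302,"
    "stun:stun4.l.google.com:19302"
)

# One RTCIceServer per comma-separated entry, whitespace trimmed.
ice_servers = [
    RTCIceServer(urls=server.strip())
    for server in WEBRTC_MODAL_PUBLIC_STUN_SERVERS.split(",")
    if server.strip()
]
print(len(ice_servers))  # 5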

inference/usage_tracking/plan_details.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 class WebRTCPlan(BaseModel):
     gpu: Optional[str] = None
+    cpu_cores: Optional[int] = None
+    ram_mb: Optional[int] = None
 
 
 class PlanDetails(SQLiteWrapper):
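With cpu_cores and ram_mb on the plan model, spawn_rtc_peer_connection_modal can pass per-plan limits straight into with_options(ram=..., cpu=...). A sketch of the extended model with a hypothetical payload; only the field names come from this diff, the values are invented:

from typing import Optional
from pydantic import BaseModel

class WebRTCPlan(BaseModel):
    gpu: Optional[str] = None
    cpu_cores: Optional[int] = None
    ram_mb: Optional[int] = None

# Hypothetical plan values; real ones come from the Roboflow plans API.
plan = WebRTCPlan(gpu="T4", cpu_cores=4, ram_mb=8192)
print(plan.gpu, plan.cpu_cores, plan.ram_mb)  # T4 4 8192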
