3131 WEBRTC_MODAL_GCP_SECRET_NAME ,
3232 WEBRTC_MODAL_IMAGE_NAME ,
3333 WEBRTC_MODAL_IMAGE_TAG ,
34+ WEBRTC_MODAL_MIN_RAM_MB ,
3435 WEBRTC_MODAL_MODELS_PRELOAD_API_KEY ,
3536 WEBRTC_MODAL_PRELOAD_HF_IDS ,
3637 WEBRTC_MODAL_PRELOAD_MODELS ,
4849 WebRTCWorkerResult ,
4950)
5051from inference .core .interfaces .webrtc_worker .utils import (
52+ warmup_cuda ,
5153 workflow_contains_instant_model ,
5254 workflow_contains_preloaded_model ,
5355)
54- from inference .core .interfaces .webrtc_worker .webrtc import (
55- init_rtc_peer_connection_with_loop ,
56- )
56+ from inference .core .interfaces .webrtc_worker .watchdog import Watchdog
5757from inference .core .managers .base import ModelManager
5858from inference .core .registries .roboflow import RoboflowModelRegistry
5959from inference .core .roboflow_api import (
6262)
6363from inference .core .version import __version__
6464from inference .models .aliases import resolve_roboflow_model_alias
65- from inference .models .owlv2 .owlv2 import preload_owlv2_model
65+ from inference .models .owlv2 .owlv2 import PRELOADED_HF_MODELS , preload_owlv2_model
6666from inference .models .utils import ROBOFLOW_MODEL_TYPES
6767from inference .usage_tracking .collector import usage_collector
6868from inference .usage_tracking .plan_details import WebRTCPlan
@@ -125,6 +125,7 @@ def check_nvidia_smi_gpu() -> str:
125125 "min_containers" : WEBRTC_MODAL_FUNCTION_MIN_CONTAINERS ,
126126 "buffer_containers" : WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS ,
127127 "scaledown_window" : WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW ,
128+ "memory" : WEBRTC_MODAL_MIN_RAM_MB ,
128129 "timeout" : WEBRTC_MODAL_FUNCTION_TIME_LIMIT ,
129130 "enable_memory_snapshot" : WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT ,
130131 "max_inputs" : WEBRTC_MODAL_FUNCTION_MAX_INPUTS ,
@@ -152,6 +153,7 @@ def check_nvidia_smi_gpu() -> str:
152153 "ROBOFLOW_INTERNAL_SERVICE_SECRET" : ROBOFLOW_INTERNAL_SERVICE_SECRET ,
153154 "WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE" : WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE ,
154155 "TELEMETRY_USE_PERSISTENT_QUEUE" : "False" ,
156+ "TORCHINDUCTOR_COMPILE_THREADS" : "1" ,
155157 "WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS" : str (
156158 WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS
157159 ),
@@ -188,6 +190,10 @@ def rtc_peer_connection_modal(
188190 webrtc_request : WebRTCWorkerRequest ,
189191 q : modal .Queue ,
190192 ):
193+ from inference .core .interfaces .webrtc_worker .webrtc import (
194+ init_rtc_peer_connection_with_loop ,
195+ )
196+
191197 logger .info ("*** Spawning %s:" , self .__class__ .__name__ )
192198 logger .info ("Running on %s" , self ._gpu )
193199 logger .info ("Inference tag: %s" , docker_tag )
@@ -199,6 +205,9 @@ def rtc_peer_connection_modal(
199205 else ""
200206 ),
201207 )
208+ logger .info (
209+ "Preloaded hf models: %s" , ", " .join (PRELOADED_HF_MODELS .keys ())
210+ )
202211 _exec_session_started = datetime .datetime .now ()
203212 webrtc_request .processing_session_started = _exec_session_started
204213 logger .info (
@@ -231,14 +240,35 @@ def rtc_peer_connection_modal(
231240 logger .info ("MODAL_ENVIRONMENT: %s" , MODAL_ENVIRONMENT )
232241 logger .info ("MODAL_IDENTITY_TOKEN: %s" , MODAL_IDENTITY_TOKEN )
233242
243+ try :
244+ current_loop = asyncio .get_running_loop ()
245+ except RuntimeError :
246+ current_loop = asyncio .new_event_loop ()
247+ asyncio .set_event_loop (current_loop )
248+
249+ def on_timeout ():
250+ def shutdown ():
251+ for task in asyncio .all_tasks ():
252+ task .cancel ()
253+ current_loop .stop ()
254+
255+ current_loop .call_soon_threadsafe (shutdown )
256+
257+ watchdog = Watchdog (
258+ timeout_seconds = 30 ,
259+ on_timeout = on_timeout ,
260+ )
261+
234262 def send_answer (obj : WebRTCWorkerResult ):
235263 logger .info ("Sending webrtc answer" )
236264 q .put (obj )
265+ watchdog .start ()
237266
238267 if webrtc_request .processing_timeout == 0 :
239268 error_msg = "Processing timeout is 0, skipping processing"
240269 logger .info (error_msg )
241270 send_answer (WebRTCWorkerResult (error_message = error_msg ))
271+ watchdog .stop ()
242272 return
243273 if (
244274 not webrtc_request .webrtc_offer
@@ -248,15 +278,21 @@ def send_answer(obj: WebRTCWorkerResult):
248278 error_msg = "Webrtc offer is missing, skipping processing"
249279 logger .info (error_msg )
250280 send_answer (WebRTCWorkerResult (error_message = error_msg ))
281+ watchdog .stop ()
251282 return
252283
253- asyncio .run (
254- init_rtc_peer_connection_with_loop (
255- webrtc_request = webrtc_request ,
256- send_answer = send_answer ,
257- model_manager = self ._model_manager ,
284+ try :
285+ asyncio .run (
286+ init_rtc_peer_connection_with_loop (
287+ webrtc_request = webrtc_request ,
288+ send_answer = send_answer ,
289+ model_manager = self ._model_manager ,
290+ heartbeat_callback = watchdog .heartbeat ,
291+ )
258292 )
259- )
293+ except Exception as exc :
294+ logger .error (exc )
295+
260296 _exec_session_stopped = datetime .datetime .now ()
261297 logger .info (
262298 "WebRTC session stopped at %s" ,
@@ -294,6 +330,7 @@ def send_answer(obj: WebRTCWorkerResult):
294330 ).total_seconds (),
295331 )
296332 usage_collector .push_usage_payloads ()
333+ watchdog .stop ()
297334 logger .info ("Function completed" )
298335
299336 @modal .exit ()
@@ -335,6 +372,7 @@ class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
335372 # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
336373 @modal .enter (snap = True )
337374 def start (self ):
375+ warmup_cuda (max_retries = 10 , retry_delay = 0.5 )
338376 self ._gpu = check_nvidia_smi_gpu ()
339377 logger .info ("Starting GPU container on %s" , self ._gpu )
340378 logger .info ("Preload hf ids: %s" , self .preload_hf_ids )
@@ -423,7 +461,6 @@ def spawn_rtc_peer_connection_modal(
423461 workflow_id = webrtc_request .workflow_configuration .workflow_id ,
424462 )
425463 )
426-
427464 tags = {"tag" : docker_tag }
428465 if workspace_id :
429466 tags ["workspace_id" ] = workspace_id
0 commit comments