@@ -3,7 +3,7 @@
 import os
 import subprocess
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional
 
 from inference.core import logger
 from inference.core.env import (
@@ -31,6 +31,7 @@
     WEBRTC_MODAL_GCP_SECRET_NAME,
     WEBRTC_MODAL_IMAGE_NAME,
     WEBRTC_MODAL_IMAGE_TAG,
+    WEBRTC_MODAL_MIN_CPU_CORES,
     WEBRTC_MODAL_MIN_RAM_MB,
     WEBRTC_MODAL_MODELS_PRELOAD_API_KEY,
     WEBRTC_MODAL_PRELOAD_HF_IDS,
@@ -39,6 +40,7 @@
     WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
     WEBRTC_MODAL_RTSP_PLACEHOLDER,
     WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
+    WEBRTC_MODAL_SHUTDOWN_RESERVE,
     WEBRTC_MODAL_TOKEN_ID,
     WEBRTC_MODAL_TOKEN_SECRET,
     WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
@@ -126,6 +128,7 @@ def check_nvidia_smi_gpu() -> str:
     "buffer_containers": WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS,
     "scaledown_window": WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW,
     "memory": WEBRTC_MODAL_MIN_RAM_MB,
+    "cpu": WEBRTC_MODAL_MIN_CPU_CORES,
     "timeout": WEBRTC_MODAL_FUNCTION_TIME_LIMIT,
     "enable_memory_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT,
     "max_inputs": WEBRTC_MODAL_FUNCTION_MAX_INPUTS,
@@ -167,17 +170,62 @@ def check_nvidia_smi_gpu() -> str:
         "WEBRTC_MODAL_FUNCTION_TIME_LIMIT": str(WEBRTC_MODAL_FUNCTION_TIME_LIMIT),
         "WEBRTC_MODAL_IMAGE_NAME": WEBRTC_MODAL_IMAGE_NAME,
         "WEBRTC_MODAL_IMAGE_TAG": WEBRTC_MODAL_IMAGE_TAG,
+        "WEBRTC_MODAL_MIN_CPU_CORES": str(
+            WEBRTC_MODAL_MIN_CPU_CORES if WEBRTC_MODAL_MIN_CPU_CORES else ""
+        ),
+        "WEBRTC_MODAL_MIN_RAM_MB": str(
+            WEBRTC_MODAL_MIN_RAM_MB if WEBRTC_MODAL_MIN_RAM_MB else ""
+        ),
         "WEBRTC_MODAL_MODELS_PRELOAD_API_KEY": (
             str(WEBRTC_MODAL_MODELS_PRELOAD_API_KEY)
             if WEBRTC_MODAL_MODELS_PRELOAD_API_KEY
             else ""
         ),
         "WEBRTC_MODAL_RTSP_PLACEHOLDER": WEBRTC_MODAL_RTSP_PLACEHOLDER,
         "WEBRTC_MODAL_RTSP_PLACEHOLDER_URL": WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
+        "WEBRTC_MODAL_SHUTDOWN_RESERVE": str(WEBRTC_MODAL_SHUTDOWN_RESERVE),
     },
     "volumes": {MODEL_CACHE_DIR: rfcache_volume},
 }
 
+async def run_rtc_peer_connection_with_watchdog(
+    webrtc_request: WebRTCWorkerRequest,
+    send_answer: Callable[[WebRTCWorkerResult], None],
+    model_manager: ModelManager,
+):
+    from inference.core.interfaces.webrtc_worker.webrtc import (
+        init_rtc_peer_connection_with_loop,
+    )
+
+    watchdog = Watchdog(
+        timeout_seconds=30,
+    )
+
+    rtc_peer_connection_task = asyncio.create_task(
+        init_rtc_peer_connection_with_loop(
+            webrtc_request=webrtc_request,
+            send_answer=send_answer,
+            model_manager=model_manager,
+            heartbeat_callback=watchdog.heartbeat,
+        )
+    )
+
+    def on_timeout():
+        logger.info("Watchdog timeout reached")
+        rtc_peer_connection_task.cancel()
+
+    watchdog.on_timeout = on_timeout
+    watchdog.start()
+
+    try:
+        await rtc_peer_connection_task
+    except asyncio.CancelledError as exc:
+        logger.info("WebRTC connection task was cancelled (%s)", exc)
+    except Exception as exc:
+        logger.error(exc)
+    finally:
+        watchdog.stop()
+
+
 class RTCPeerConnectionModal:
     _model_manager: Optional[ModelManager] = modal.parameter(
         default=None, init=False
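The new wrapper leans on a `Watchdog` helper whose definition is not part of this diff. Below is a minimal sketch of the contract it appears to rely on (hypothetical; the real class lives elsewhere in the inference codebase): `heartbeat()` pushes the deadline back, `start()`/`stop()` arm and disarm it, and `on_timeout` fires once no heartbeat arrives within `timeout_seconds`, which here cancels the peer-connection task.

```python
import threading
from typing import Callable, Optional


class Watchdog:
    """Sketch of the watchdog interface used above (not the real implementation)."""

    def __init__(
        self,
        timeout_seconds: float,
        on_timeout: Optional[Callable[[], None]] = None,
    ):
        self.timeout_seconds = timeout_seconds
        # May also be assigned after construction, as the diff does.
        self.on_timeout = on_timeout
        self._timer: Optional[threading.Timer] = None
        self._lock = threading.Lock()

    def _arm(self) -> None:
        # (Re)start the countdown; called with the lock held.
        if self._timer is not None:
            self._timer.cancel()
        self._timer = threading.Timer(self.timeout_seconds, self._fire)
        self._timer.daemon = True
        self._timer.start()

    def _fire(self) -> None:
        if self.on_timeout is not None:
            self.on_timeout()

    def start(self) -> None:
        with self._lock:
            self._arm()

    def heartbeat(self) -> None:
        # Each heartbeat from the WebRTC loop pushes the deadline back.
        with self._lock:
            if self._timer is not None:
                self._arm()

    def stop(self) -> None:
        with self._lock:
            if self._timer is not None:
                self._timer.cancel()
                self._timer = None
```

If the real watchdog fires its callback from a timer thread like this sketch, the cancellation would normally be marshalled onto the event loop (the removed code used `loop.call_soon_threadsafe` for exactly that); whether that is needed here depends on the actual `Watchdog` implementation.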
@@ -190,10 +238,6 @@ def rtc_peer_connection_modal(
         webrtc_request: WebRTCWorkerRequest,
         q: modal.Queue,
     ):
-        from inference.core.interfaces.webrtc_worker.webrtc import (
-            init_rtc_peer_connection_with_loop,
-        )
-
         logger.info("*** Spawning %s:", self.__class__.__name__)
         logger.info("Running on %s", self._gpu)
         logger.info("Inference tag: %s", docker_tag)
@@ -223,7 +267,6 @@ def rtc_peer_connection_modal(
         logger.info("rtsp_url: %s", webrtc_request.rtsp_url)
         logger.info("processing_timeout: %s", webrtc_request.processing_timeout)
         logger.info("requested_plan: %s", webrtc_request.requested_plan)
-        logger.info("requested_gpu: %s", webrtc_request.requested_gpu)
         logger.info("requested_region: %s", webrtc_request.requested_region)
         logger.info(
             "ICE servers: %s",
@@ -233,42 +276,28 @@ def send_answer(obj: WebRTCWorkerResult):
                 else []
             ),
         )
+        logger.info(
+            "WEBRTC_MODAL_MIN_CPU_CORES: %s",
+            WEBRTC_MODAL_MIN_CPU_CORES or "not set",
+        )
+        logger.info(
+            "WEBRTC_MODAL_MIN_RAM_MB: %s", WEBRTC_MODAL_MIN_RAM_MB or "not set"
+        )
         logger.info("MODAL_CLOUD_PROVIDER: %s", MODAL_CLOUD_PROVIDER)
         logger.info("MODAL_IMAGE_ID: %s", MODAL_IMAGE_ID)
         logger.info("MODAL_REGION: %s", MODAL_REGION)
         logger.info("MODAL_TASK_ID: %s", MODAL_TASK_ID)
         logger.info("MODAL_ENVIRONMENT: %s", MODAL_ENVIRONMENT)
         logger.info("MODAL_IDENTITY_TOKEN: %s", MODAL_IDENTITY_TOKEN)
 
-        try:
-            current_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            current_loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(current_loop)
-
-        def on_timeout():
-            def shutdown():
-                for task in asyncio.all_tasks():
-                    task.cancel()
-                current_loop.stop()
-
-            current_loop.call_soon_threadsafe(shutdown)
-
-        watchdog = Watchdog(
-            timeout_seconds=30,
-            on_timeout=on_timeout,
-        )
-
         def send_answer(obj: WebRTCWorkerResult):
             logger.info("Sending webrtc answer")
             q.put(obj)
-            watchdog.start()
 
         if webrtc_request.processing_timeout == 0:
             error_msg = "Processing timeout is 0, skipping processing"
             logger.info(error_msg)
             send_answer(WebRTCWorkerResult(error_message=error_msg))
-            watchdog.stop()
             return
         if (
             not webrtc_request.webrtc_offer
@@ -278,20 +307,15 @@ def send_answer(obj: WebRTCWorkerResult):
             error_msg = "Webrtc offer is missing, skipping processing"
             logger.info(error_msg)
             send_answer(WebRTCWorkerResult(error_message=error_msg))
-            watchdog.stop()
             return
 
-        try:
-            asyncio.run(
-                init_rtc_peer_connection_with_loop(
-                    webrtc_request=webrtc_request,
-                    send_answer=send_answer,
-                    model_manager=self._model_manager,
-                    heartbeat_callback=watchdog.heartbeat,
-                )
+        asyncio.run(
+            run_rtc_peer_connection_with_watchdog(
+                webrtc_request=webrtc_request,
+                send_answer=send_answer,
+                model_manager=self._model_manager,
             )
-        except Exception as exc:
-            logger.error(exc)
+        )
 
         _exec_session_stopped = datetime.datetime.now()
         logger.info(
@@ -315,6 +339,8 @@ def send_answer(obj: WebRTCWorkerResult):
             video_source = "rtsp"
         elif not webrtc_request.webrtc_realtime_processing:
             video_source = "buffered browser stream"
+        else:
+            video_source = "realtime browser stream"
 
         usage_collector.record_usage(
             source=workflow_id,
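Read as a small pure function, the labelling that feeds `usage_collector.record_usage` now covers all three cases (sketch only; the real code keeps this inline):

```python
from typing import Optional


def classify_video_source(rtsp_url: Optional[str], webrtc_realtime_processing: bool) -> str:
    # Mirrors the if/elif/else above: an RTSP URL wins, otherwise the browser
    # stream is labelled by whether frames are processed in realtime or buffered.
    if rtsp_url:
        return "rtsp"
    if not webrtc_realtime_processing:
        return "buffered browser stream"
    return "realtime browser stream"
```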
@@ -330,7 +356,6 @@ def send_answer(obj: WebRTCWorkerResult):
             ).total_seconds(),
         )
         usage_collector.push_usage_payloads()
-        watchdog.stop()
         logger.info("Function completed")
 
     @modal.exit()
@@ -411,6 +436,9 @@ def start(self):
 def spawn_rtc_peer_connection_modal(
     webrtc_request: WebRTCWorkerRequest,
 ) -> WebRTCWorkerResult:
+    requested_gpu: Optional[str] = None
+    requested_ram_mb: Optional[int] = None
+    requested_cpu_cores: Optional[int] = None
     webrtc_plans: Optional[Dict[str, WebRTCPlan]] = (
         usage_collector._plan_details.get_webrtc_plans(
             api_key=webrtc_request.api_key
@@ -421,9 +449,11 @@ def spawn_rtc_peer_connection_modal(
             raise RoboflowAPIUnsuccessfulRequestError(
                 f"Unknown requested plan {webrtc_request.requested_plan}, available plans: {', '.join(webrtc_plans.keys())}"
             )
-        webrtc_request.requested_gpu = webrtc_plans[
-            webrtc_request.requested_plan
-        ].gpu
+        requested_gpu = webrtc_plans[webrtc_request.requested_plan].gpu
+        requested_ram_mb = webrtc_plans[webrtc_request.requested_plan].ram_mb
+        requested_cpu_cores = webrtc_plans[webrtc_request.requested_plan].cpu_cores
+
+    # TODO: requested_gpu is replaced with requested_plan
     if (
         webrtc_plans
         and not webrtc_request.requested_plan
@@ -435,6 +465,7 @@ def spawn_rtc_peer_connection_modal(
                 f"Requested gpu {webrtc_request.requested_gpu} not associated with any plan, available gpus: {', '.join(gpu_to_plan.keys())}"
             )
         webrtc_request.requested_plan = gpu_to_plan[webrtc_request.requested_gpu]
+        requested_gpu = webrtc_plans[webrtc_request.requested_plan].gpu
 
     # https://modal.com/docs/reference/modal.Client#from_credentials
     client = modal.Client.from_credentials(
@@ -483,7 +514,7 @@ def spawn_rtc_peer_connection_modal(
     logger.info("Parametrized preload models: %s", WEBRTC_MODAL_PRELOAD_MODELS)
     preload_models = WEBRTC_MODAL_PRELOAD_MODELS
 
-    if webrtc_request.requested_gpu:
+    if requested_gpu:
         RTCPeerConnectionModal = RTCPeerConnectionModalGPU
     else:
         RTCPeerConnectionModal = RTCPeerConnectionModalCPU
@@ -505,16 +536,16 @@ def spawn_rtc_peer_connection_modal(
     cls_with_options = deployed_cls.with_options(
         timeout=webrtc_request.processing_timeout,
     )
-    if webrtc_request.requested_gpu is not None:
+    if requested_gpu is not None:
        logger.info(
            "Spawning webrtc modal function with gpu %s",
-            webrtc_request.requested_gpu,
+            requested_gpu,
        )
         # Specify fallback GPU
         # TODO: with_options does not support gpu fallback
         # https://modal.com/docs/examples/gpu_fallbacks#set-fallback-gpus
         cls_with_options = cls_with_options.with_options(
-            gpu=webrtc_request.requested_gpu,
+            gpu=requested_gpu,
         )
     if webrtc_request.requested_region:
         logger.info(
@@ -524,6 +555,22 @@ def send_answer(obj: WebRTCWorkerResult):
         cls_with_options = cls_with_options.with_options(
             region=webrtc_request.requested_region,
         )
+    if requested_ram_mb is not None:
+        logger.info(
+            "Spawning webrtc modal function with ram %s",
+            requested_ram_mb,
+        )
+        cls_with_options = cls_with_options.with_options(
+            ram=requested_ram_mb,
+        )
+    if requested_cpu_cores is not None:
+        logger.info(
+            "Spawning webrtc modal function with cpu cores %s",
+            requested_cpu_cores,
+        )
+        cls_with_options = cls_with_options.with_options(
+            cpu=requested_cpu_cores,
+        )
     rtc_modal_obj: RTCPeerConnectionModal = cls_with_options(
         preload_hf_ids=preload_hf_ids,
         preload_models=preload_models,
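Taken together, the plan resolution plus the chain of `with_options` calls amount to mapping the resolved plan onto per-call resource overrides. A sketch under the same assumptions as the diff (the `gpu`, `ram_mb`, and `cpu_cores` attributes on `WebRTCPlan`, and override keys matching the `with_options(gpu=...)`, `with_options(ram=...)`, `with_options(cpu=...)` calls above; whether a given Modal version accepts all of them in a single `with_options` call is not verified here, which is why the diff applies them one at a time):

```python
from typing import Any, Dict, Optional


def plan_resource_overrides(plan: Optional[Any]) -> Dict[str, Any]:
    """Collect per-call resource overrides from a resolved plan (sketch only).

    Mirrors the inline logic in spawn_rtc_peer_connection_modal: any attribute
    that is missing or None is skipped, so the defaults baked into the deployed
    class (memory/cpu/timeout from the decorator kwargs) stay in effect.
    """
    overrides: Dict[str, Any] = {}
    if plan is None:
        return overrides
    if getattr(plan, "gpu", None) is not None:
        overrides["gpu"] = plan.gpu
    if getattr(plan, "ram_mb", None) is not None:
        overrides["ram"] = plan.ram_mb
    if getattr(plan, "cpu_cores", None) is not None:
        overrides["cpu"] = plan.cpu_cores
    return overrides
```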