Commit 0fc6cf1

Merge branch 'main' into feature/inference-exp-jp-5

2 parents 3a5e7d5 + d9dfeda

File tree

11 files changed: +2358 −80 lines

inference/core/env.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -765,6 +765,7 @@
 WEBRTC_MODAL_MODELS_PRELOAD_API_KEY = os.getenv("WEBRTC_MODAL_MODELS_PRELOAD_API_KEY")
 WEBRTC_MODAL_PRELOAD_MODELS = os.getenv("WEBRTC_MODAL_PRELOAD_MODELS")
 WEBRTC_MODAL_PRELOAD_HF_IDS = os.getenv("WEBRTC_MODAL_PRELOAD_HF_IDS")
+WEBRTC_MODAL_MIN_RAM_MB = int(os.getenv("WEBRTC_MODAL_MIN_RAM_MB", "4096"))
 
 HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED = str2bool(
     os.getenv("HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED", "True")
```
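The new setting is parsed with `int()` at import time of `inference.core.env`, so any override must be in the environment before inference modules are imported. A minimal sketch, assuming the default of 4096 MB and an illustrative override value:

```python
# Sketch: overriding WEBRTC_MODAL_MIN_RAM_MB. It is read once at import time,
# so set the environment variable before importing inference.
import os

os.environ["WEBRTC_MODAL_MIN_RAM_MB"] = "8192"  # example value, in MB

from inference.core.env import WEBRTC_MODAL_MIN_RAM_MB

print(WEBRTC_MODAL_MIN_RAM_MB)  # -> 8192
```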

inference/core/interfaces/stream/inference_pipeline.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -596,10 +596,9 @@ def init_with_workflow(
         named_workflow_specified = (workspace_name is not None) and (
             workflow_id is not None
         )
-        if not (named_workflow_specified != (workflow_specification is not None)):
+        if not named_workflow_specified and not workflow_specification:
             raise ValueError(
-                "Parameters (`workspace_name`, `workflow_id`) can be used mutually exclusive with "
-                "`workflow_specification`, but at least one must be set."
+                "Either (`workspace_name`, `workflow_id`) or `workflow_specification` must be provided."
             )
         try:
             from inference.core.interfaces.stream.model_handlers.workflows import (
```
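Note that this hunk changes semantics, not just wording: the old XOR-style guard raised when a named workflow and a specification were supplied together, while the new guard raises only when neither is given. A minimal sketch of the difference:

```python
# Sketch: truth table for the old XOR-style guard vs. the new one.
def old_guard_raises(named: bool, spec: bool) -> bool:
    # raised unless exactly one of the two was provided
    return not (named != spec)

def new_guard_raises(named: bool, spec: bool) -> bool:
    # raises only when neither is provided
    return not named and not spec

for named in (False, True):
    for spec in (False, True):
        print(named, spec, old_guard_raises(named, spec), new_guard_raises(named, spec))
# Only the (True, True) row differs: providing both no longer raises.
```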

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 48 additions & 11 deletions

```diff
@@ -31,6 +31,7 @@
     WEBRTC_MODAL_GCP_SECRET_NAME,
     WEBRTC_MODAL_IMAGE_NAME,
     WEBRTC_MODAL_IMAGE_TAG,
+    WEBRTC_MODAL_MIN_RAM_MB,
     WEBRTC_MODAL_MODELS_PRELOAD_API_KEY,
     WEBRTC_MODAL_PRELOAD_HF_IDS,
     WEBRTC_MODAL_PRELOAD_MODELS,
@@ -48,12 +49,11 @@
     WebRTCWorkerResult,
 )
 from inference.core.interfaces.webrtc_worker.utils import (
+    warmup_cuda,
     workflow_contains_instant_model,
     workflow_contains_preloaded_model,
 )
-from inference.core.interfaces.webrtc_worker.webrtc import (
-    init_rtc_peer_connection_with_loop,
-)
+from inference.core.interfaces.webrtc_worker.watchdog import Watchdog
 from inference.core.managers.base import ModelManager
 from inference.core.registries.roboflow import RoboflowModelRegistry
 from inference.core.roboflow_api import (
@@ -62,7 +62,7 @@
 )
 from inference.core.version import __version__
 from inference.models.aliases import resolve_roboflow_model_alias
-from inference.models.owlv2.owlv2 import preload_owlv2_model
+from inference.models.owlv2.owlv2 import PRELOADED_HF_MODELS, preload_owlv2_model
 from inference.models.utils import ROBOFLOW_MODEL_TYPES
 from inference.usage_tracking.collector import usage_collector
 from inference.usage_tracking.plan_details import WebRTCPlan
@@ -125,6 +125,7 @@ def check_nvidia_smi_gpu() -> str:
     "min_containers": WEBRTC_MODAL_FUNCTION_MIN_CONTAINERS,
     "buffer_containers": WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS,
     "scaledown_window": WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW,
+    "memory": WEBRTC_MODAL_MIN_RAM_MB,
     "timeout": WEBRTC_MODAL_FUNCTION_TIME_LIMIT,
     "enable_memory_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT,
     "max_inputs": WEBRTC_MODAL_FUNCTION_MAX_INPUTS,
@@ -152,6 +153,7 @@ def check_nvidia_smi_gpu() -> str:
     "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
     "WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE": WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
     "TELEMETRY_USE_PERSISTENT_QUEUE": "False",
+    "TORCHINDUCTOR_COMPILE_THREADS": "1",
     "WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS": str(
         WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS
     ),
@@ -188,6 +190,10 @@ def rtc_peer_connection_modal(
         webrtc_request: WebRTCWorkerRequest,
         q: modal.Queue,
     ):
+        from inference.core.interfaces.webrtc_worker.webrtc import (
+            init_rtc_peer_connection_with_loop,
+        )
+
         logger.info("*** Spawning %s:", self.__class__.__name__)
         logger.info("Running on %s", self._gpu)
         logger.info("Inference tag: %s", docker_tag)
@@ -199,6 +205,9 @@ def rtc_peer_connection_modal(
                 else ""
             ),
         )
+        logger.info(
+            "Preloaded hf models: %s", ", ".join(PRELOADED_HF_MODELS.keys())
+        )
         _exec_session_started = datetime.datetime.now()
         webrtc_request.processing_session_started = _exec_session_started
         logger.info(
@@ -231,14 +240,35 @@ def rtc_peer_connection_modal(
         logger.info("MODAL_ENVIRONMENT: %s", MODAL_ENVIRONMENT)
         logger.info("MODAL_IDENTITY_TOKEN: %s", MODAL_IDENTITY_TOKEN)
 
+        try:
+            current_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            current_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(current_loop)
+
+        def on_timeout():
+            def shutdown():
+                for task in asyncio.all_tasks():
+                    task.cancel()
+                current_loop.stop()
+
+            current_loop.call_soon_threadsafe(shutdown)
+
+        watchdog = Watchdog(
+            timeout_seconds=30,
+            on_timeout=on_timeout,
+        )
+
         def send_answer(obj: WebRTCWorkerResult):
             logger.info("Sending webrtc answer")
             q.put(obj)
+            watchdog.start()
 
         if webrtc_request.processing_timeout == 0:
            error_msg = "Processing timeout is 0, skipping processing"
            logger.info(error_msg)
            send_answer(WebRTCWorkerResult(error_message=error_msg))
+           watchdog.stop()
            return
         if (
             not webrtc_request.webrtc_offer
@@ -248,15 +278,21 @@ def send_answer(obj: WebRTCWorkerResult):
             error_msg = "Webrtc offer is missing, skipping processing"
             logger.info(error_msg)
             send_answer(WebRTCWorkerResult(error_message=error_msg))
+            watchdog.stop()
             return
 
-        asyncio.run(
-            init_rtc_peer_connection_with_loop(
-                webrtc_request=webrtc_request,
-                send_answer=send_answer,
-                model_manager=self._model_manager,
+        try:
+            asyncio.run(
+                init_rtc_peer_connection_with_loop(
+                    webrtc_request=webrtc_request,
+                    send_answer=send_answer,
+                    model_manager=self._model_manager,
+                    heartbeat_callback=watchdog.heartbeat,
+                )
             )
-        )
+        except Exception as exc:
+            logger.error(exc)
 
         _exec_session_stopped = datetime.datetime.now()
         logger.info(
             "WebRTC session stopped at %s",
@@ -294,6 +330,7 @@ def send_answer(obj: WebRTCWorkerResult):
             ).total_seconds(),
         )
         usage_collector.push_usage_payloads()
+        watchdog.stop()
         logger.info("Function completed")
 
     @modal.exit()
@@ -335,6 +372,7 @@ class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
     # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
     @modal.enter(snap=True)
     def start(self):
+        warmup_cuda(max_retries=10, retry_delay=0.5)
         self._gpu = check_nvidia_smi_gpu()
         logger.info("Starting GPU container on %s", self._gpu)
         logger.info("Preload hf ids: %s", self.preload_hf_ids)
@@ -423,7 +461,6 @@ def spawn_rtc_peer_connection_modal(
             workflow_id=webrtc_request.workflow_configuration.workflow_id,
         )
     )
-
     tags = {"tag": docker_tag}
     if workspace_id:
         tags["workspace_id"] = workspace_id
```

inference/core/interfaces/webrtc_worker/utils.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -1,5 +1,7 @@
+import ctypes
 import datetime
 import logging
+import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cv2 as cv
@@ -181,3 +183,29 @@ def workflow_contains_preloaded_model(
         if model_id in preload_models or resolved_model_id in preload_models:
             return True
     return False
+
+
+def warmup_cuda(
+    max_retries: int = 10,
+    retry_delay: float = 0.5,
+):
+    cu = ctypes.CDLL("libcuda.so.1")
+
+    for attempt in range(max_retries):
+        rc = cu.cuInit(0)
+
+        if rc == 0:
+            break
+        else:
+            if attempt < max_retries - 1:
+                logger.warning(
+                    "cuInit failed on attempt %s/%s with code %s, retrying...",
+                    attempt + 1,
+                    max_retries,
+                    rc,
+                )
+                time.sleep(retry_delay)
+            else:
+                raise RuntimeError(
+                    f"CUDA initialization failed after {max_retries} attempts"
+                )
+
+    logger.info("CUDA initialization succeeded")
```
inference/core/interfaces/webrtc_worker/watchdog.py (new file)

Lines changed: 50 additions & 0 deletions

```diff
@@ -0,0 +1,50 @@
+import datetime
+import threading
+import time
+from typing import Callable
+
+from inference.core.logger import logger
+
+
+class Watchdog:
+    def __init__(self, timeout_seconds: int, on_timeout: Callable[[], None]):
+        self.timeout_seconds = timeout_seconds
+        self.last_heartbeat = datetime.datetime.now()
+        self.on_timeout = on_timeout
+        self._thread = threading.Thread(target=self._watchdog_thread)
+        self._stopping = False
+        self._last_log_ts = datetime.datetime.now()
+        self._log_interval_seconds = 10
+        self._heartbeats = 0
+
+    def start(self):
+        self._thread.start()
+
+    def stop(self):
+        self._stopping = True
+        if self._thread.is_alive():
+            self._thread.join()
+
+    def _watchdog_thread(self):
+        while not self._stopping:
+            if not self.is_alive():
+                logger.error("Watchdog timeout reached")
+                self.on_timeout()
+                break
+            time.sleep(0.1)
+        logger.info("Watchdog stopped")
+
+    def heartbeat(self):
+        self.last_heartbeat = datetime.datetime.now()
+        self._heartbeats += 1
+        if (
+            datetime.datetime.now() - self._last_log_ts
+        ).total_seconds() > self._log_interval_seconds:
+            logger.info("Watchdog heartbeat (%s since last)", self._heartbeats)
+            self._last_log_ts = datetime.datetime.now()
+            self._heartbeats = 0
+
+    def is_alive(self) -> bool:
+        return (
+            datetime.datetime.now() - self.last_heartbeat
+        ).total_seconds() < self.timeout_seconds
```
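As wired up in modal.py, the watchdog only starts once the WebRTC answer has been sent, and `heartbeat()` is invoked from the frame-processing paths. A minimal standalone usage sketch of the class above (timeout value is illustrative):

```python
# Sketch: standalone Watchdog usage. Heartbeats keep it quiet; once they stop
# for longer than timeout_seconds, on_timeout fires once and the thread exits.
import time

from inference.core.interfaces.webrtc_worker.watchdog import Watchdog

def on_timeout():
    print("no heartbeat for 2s -> trigger shutdown")

watchdog = Watchdog(timeout_seconds=2, on_timeout=on_timeout)
watchdog.start()

for _ in range(10):
    watchdog.heartbeat()  # what the frame loop calls on every iteration
    time.sleep(0.1)

time.sleep(3)    # heartbeats stop; the watchdog fires and breaks out of its loop
watchdog.stop()  # joins the (already finished) watchdog thread
```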

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 26 additions & 2 deletions

```diff
@@ -7,6 +7,7 @@
 
 import cv2
 import numpy as np
+from aioice import ice
 from aiortc import (
     RTCConfiguration,
     RTCDataChannel,
@@ -202,6 +203,7 @@ def __init__(
         termination_date: Optional[datetime.datetime] = None,
         terminate_event: Optional[asyncio.Event] = None,
         use_data_channel_frames: bool = False,
+        heartbeat_callback: Optional[Callable[[], None]] = None,
     ):
         self._loop = asyncio_loop
         self._termination_date = termination_date
@@ -212,6 +214,7 @@ def __init__(
         self._received_frames = 0
         self._declared_fps = declared_fps
         self._stop_processing = False
+        self.heartbeat_callback = heartbeat_callback
         self.use_data_channel_frames = use_data_channel_frames
         self._data_frame_queue: "asyncio.Queue[Optional[VideoFrame]]" = asyncio.Queue()
         self._chunk_reassembler = (
@@ -267,8 +270,8 @@ def _check_termination(self):
         if (
             self._termination_date
             and self._termination_date < datetime.datetime.now()
-            and self._terminate_event
-            and not self._terminate_event.is_set()
+            or self._terminate_event
+            and self._terminate_event.is_set()
         ):
             logger.info("Timeout reached, terminating inference pipeline")
             self._terminate_event.set()
@@ -401,6 +404,8 @@ async def process_frames_data_only(self):
         while not self._stop_processing:
             if self._check_termination():
                 break
+            if self.heartbeat_callback:
+                self.heartbeat_callback()
 
             # Get frame from appropriate source
             if self.use_data_channel_frames:
@@ -547,6 +552,7 @@ def __init__(
         termination_date: Optional[datetime.datetime] = None,
         terminate_event: Optional[asyncio.Event] = None,
         use_data_channel_frames: bool = False,
+        heartbeat_callback: Optional[Callable[[], None]] = None,
         *args,
         **kwargs,
     ):
@@ -564,6 +570,7 @@ def __init__(
             terminate_event=terminate_event,
             use_data_channel_frames=use_data_channel_frames,
             model_manager=model_manager,
+            heartbeat_callback=heartbeat_callback,
         )
 
     async def _auto_detect_stream_output(
@@ -589,6 +596,9 @@ async def recv(self):
             av_logging.set_libav_level(av_logging.ERROR)
             self._av_logging_set = True
 
+        if self.heartbeat_callback:
+            self.heartbeat_callback()
+
         # Check if we should terminate
         if self._check_termination():
             raise MediaStreamError("Processing terminated due to timeout")
@@ -649,7 +659,19 @@ async def init_rtc_peer_connection_with_loop(
     asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
     model_manager: Optional[ModelManager] = None,
     shutdown_reserve: int = WEBRTC_MODAL_SHUTDOWN_RESERVE,
+    heartbeat_callback: Optional[Callable[[], None]] = None,
 ) -> RTCPeerConnectionWithLoop:
+    # ice._mdns is instantiated on the module level, it has a lock that is bound to the event loop
+    # avoid RuntimeError: asyncio.locks.Lock is bound to a different event loop
+    if hasattr(ice, "_mdns"):
+        if hasattr(ice._mdns, "lock"):
+            logger.info("Removing lock from aioice.ice._mdns")
+            delattr(ice._mdns, "lock")
+    else:
+        logger.warning(
+            "aioice.ice implementation was changed, _mdns attribute is not available"
+        )
+
     termination_date = None
     terminate_event = asyncio.Event()
 
@@ -708,6 +730,7 @@ async def init_rtc_peer_connection_with_loop(
                 termination_date=termination_date,
                 terminate_event=terminate_event,
                 use_data_channel_frames=webrtc_request.use_data_channel_frames,
+                heartbeat_callback=heartbeat_callback,
             )
         else:
             # No video track - use base VideoFrameProcessor
@@ -723,6 +746,7 @@ async def init_rtc_peer_connection_with_loop(
                 termination_date=termination_date,
                 terminate_event=terminate_event,
                 use_data_channel_frames=webrtc_request.use_data_channel_frames,
+                heartbeat_callback=heartbeat_callback,
             )
     except (
         ValidationError,
```
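One behavioral note on the `_check_termination` hunk: since `and` binds tighter than `or` in Python, the rewritten condition groups as `(deadline set and deadline passed) or (event set and event already set)`, so the check now also fires when the terminate event was set externally, not only on deadline expiry. A runnable sketch of the parse (the helper name is illustrative):

```python
# Sketch: the new condition with explicit parentheses; `and` binds tighter
# than `or`, so this is how Python groups the four clauses.
import asyncio
import datetime
from typing import Optional

def should_terminate(
    termination_date: Optional[datetime.datetime],
    terminate_event: Optional[asyncio.Event],
) -> bool:
    return bool(
        (termination_date and termination_date < datetime.datetime.now())
        or (terminate_event and terminate_event.is_set())
    )

event = asyncio.Event()
event.set()
print(should_terminate(None, event))  # True: an externally set event now triggers termination
```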

inference/core/workflows/core_steps/loader.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -164,6 +164,9 @@
 from inference.core.workflows.core_steps.models.foundation.anthropic_claude.v1 import (
     AnthropicClaudeBlockV1,
 )
+from inference.core.workflows.core_steps.models.foundation.anthropic_claude.v2 import (
+    AnthropicClaudeBlockV2,
+)
 from inference.core.workflows.core_steps.models.foundation.clip.v1 import (
     ClipModelBlockV1,
 )
@@ -582,6 +585,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
         DimensionCollapseBlockV1,
         FirstNonEmptyOrDefaultBlockV1,
         AnthropicClaudeBlockV1,
+        AnthropicClaudeBlockV2,
         CosineSimilarityBlockV1,
         BackgroundColorVisualizationBlockV1,
         BarcodeDetectorBlockV1,
```
