Skip to content

Commit 7f5f7a4

Browse files
Merge pull request #1722 from roboflow/feat/modal-exec-time
Modal parametrization / webrtc turn config / perspective correction fix
2 parents 10f69ed + 5399eb1 commit 7f5f7a4

File tree

11 files changed

+260
-59
lines changed

11 files changed

+260
-59
lines changed

inference/core/env.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,10 @@
705705
WEBRTC_MODAL_FUNCTION_TIME_LIMIT = int(
706706
os.getenv("WEBRTC_MODAL_FUNCTION_TIME_LIMIT", "3600")
707707
)
708+
# seconds
709+
WEBRTC_MODAL_FUNCTION_MAX_TIME_LIMIT = int(
710+
os.getenv("WEBRTC_MODAL_FUNCTION_MAX_TIME_LIMIT", "604800") # 7 days
711+
)
708712
WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT = str2bool(
709713
os.getenv("WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT", "True")
710714
)

inference/core/interfaces/webrtc_worker/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from inference.core.env import WEBRTC_MODAL_TOKEN_ID, WEBRTC_MODAL_TOKEN_SECRET
55
from inference.core.interfaces.webrtc_worker.cpu import rtc_peer_connection_process
66
from inference.core.interfaces.webrtc_worker.entities import (
7+
RTCIceServer,
8+
WebRTCConfig,
79
WebRTCWorkerRequest,
810
WebRTCWorkerResult,
911
)
@@ -12,6 +14,17 @@
1214
async def start_worker(
1315
webrtc_request: WebRTCWorkerRequest,
1416
) -> WebRTCWorkerResult:
17+
if webrtc_request.webrtc_turn_config:
18+
webrtc_request.webrtc_config = WebRTCConfig(
19+
iceServers=[
20+
RTCIceServer(
21+
urls=[webrtc_request.webrtc_turn_config.urls],
22+
username=webrtc_request.webrtc_turn_config.username,
23+
credential=webrtc_request.webrtc_turn_config.credential,
24+
)
25+
]
26+
)
27+
1528
if WEBRTC_MODAL_TOKEN_ID and WEBRTC_MODAL_TOKEN_SECRET:
1629
try:
1730
from inference.core.interfaces.webrtc_worker.modal import (

inference/core/interfaces/webrtc_worker/entities.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,22 @@
1414
)
1515

1616

17+
class RTCIceServer(BaseModel):
18+
urls: List[str]
19+
username: Optional[str] = None
20+
credential: Optional[str] = None
21+
22+
23+
class WebRTCConfig(BaseModel):
24+
iceServers: List[RTCIceServer]
25+
26+
1727
class WebRTCWorkerRequest(BaseModel):
1828
api_key: Optional[str] = None
1929
workflow_configuration: WorkflowConfiguration
2030
webrtc_offer: WebRTCOffer
31+
webrtc_config: Optional[WebRTCConfig] = None
32+
# TODO: to be removed, replaced with webrtc_config
2133
webrtc_turn_config: Optional[WebRTCTURNConfig] = None
2234
webrtc_realtime_processing: bool = (
2335
WEBRTC_REALTIME_PROCESSING # when set to True, MediaRelay.subscribe will be called with buffered=False
@@ -27,21 +39,11 @@ class WebRTCWorkerRequest(BaseModel):
2739
declared_fps: Optional[float] = None
2840
rtsp_url: Optional[str] = None
2941
processing_timeout: Optional[int] = WEBRTC_MODAL_FUNCTION_TIME_LIMIT
30-
# https://modal.com/docs/guide/gpu#specifying-gpu-type
31-
requested_gpu: Optional[
32-
Literal[
33-
"T4",
34-
"L4",
35-
"A10",
36-
"A100",
37-
"A100-40GB",
38-
"A100-80GB",
39-
"L40S",
40-
"H100/H100!",
41-
"H200",
42-
"B200",
43-
]
44-
] = "T4"
42+
requested_plan: Optional[str] = "webrtc-gpu-small"
43+
# TODO: replaced with requested_plan
44+
requested_gpu: Optional[str] = None
45+
# must be valid region: https://modal.com/docs/guide/region-selection#region-options
46+
requested_region: Optional[str] = None
4547

4648

4749
class WebRTCVideoMetadata(BaseModel):

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 118 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2+
import datetime
23
from pathlib import Path
4+
from typing import Dict, Optional
35

46
from inference.core import logger
57
from inference.core.env import (
@@ -20,6 +22,7 @@
2022
WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT,
2123
WEBRTC_MODAL_FUNCTION_GPU,
2224
WEBRTC_MODAL_FUNCTION_MAX_INPUTS,
25+
WEBRTC_MODAL_FUNCTION_MAX_TIME_LIMIT,
2326
WEBRTC_MODAL_FUNCTION_MIN_CONTAINERS,
2427
WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW,
2528
WEBRTC_MODAL_FUNCTION_TIME_LIMIT,
@@ -33,6 +36,7 @@
3336
WEBRTC_MODAL_TOKEN_SECRET,
3437
WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
3538
)
39+
from inference.core.exceptions import RoboflowAPIUnsuccessfulRequestError
3640
from inference.core.interfaces.webrtc_worker.entities import (
3741
WebRTCWorkerRequest,
3842
WebRTCWorkerResult,
@@ -41,6 +45,8 @@
4145
init_rtc_peer_connection_with_loop,
4246
)
4347
from inference.core.version import __version__
48+
from inference.usage_tracking.collector import usage_collector
49+
from inference.usage_tracking.plan_details import WebRTCPlan
4450

4551
try:
4652
import modal
@@ -118,13 +124,42 @@
118124
}
119125

120126
class RTCPeerConnectionModal:
127+
_webrtc_request: Optional[WebRTCWorkerRequest] = modal.parameter(default=None)
128+
_exec_session_started: Optional[datetime.datetime] = modal.parameter(
129+
default=None
130+
)
131+
_exec_session_stopped: Optional[datetime.datetime] = modal.parameter(
132+
default=None
133+
)
134+
121135
@modal.method()
122136
def rtc_peer_connection_modal(
123137
self,
124138
webrtc_request: WebRTCWorkerRequest,
125139
q: modal.Queue,
126140
):
127-
logger.info("Received webrtc offer")
141+
logger.info("*** Spawning %s:", self.__class__.__name__)
142+
logger.info(
143+
"webrtc_realtime_processing: %s",
144+
webrtc_request.webrtc_realtime_processing,
145+
)
146+
logger.info("stream_output: %s", webrtc_request.stream_output)
147+
logger.info("data_output: %s", webrtc_request.data_output)
148+
logger.info("declared_fps: %s", webrtc_request.declared_fps)
149+
logger.info("rtsp_url: %s", webrtc_request.rtsp_url)
150+
logger.info("processing_timeout: %s", webrtc_request.processing_timeout)
151+
logger.info("requested_plan: %s", webrtc_request.requested_plan)
152+
logger.info("requested_gpu: %s", webrtc_request.requested_gpu)
153+
logger.info("requested_region: %s", webrtc_request.requested_region)
154+
logger.info(
155+
"ICE servers: %s",
156+
len(
157+
webrtc_request.webrtc_config.iceServers
158+
if webrtc_request.webrtc_config
159+
else []
160+
),
161+
)
162+
self._webrtc_request = webrtc_request
128163

129164
def send_answer(obj: WebRTCWorkerResult):
130165
logger.info("Sending webrtc answer")
@@ -137,22 +172,56 @@ def send_answer(obj: WebRTCWorkerResult):
137172
)
138173
)
139174

175+
# https://modal.com/docs/reference/modal.enter
176+
# Modal usage calculation is relying on no concurrency and no hot instances
177+
@modal.enter()
178+
def start(self):
179+
self._exec_session_started = datetime.datetime.now()
180+
181+
@modal.exit()
182+
def stop(self):
183+
if not self._webrtc_request:
184+
return
185+
self._exec_session_stopped = datetime.datetime.now()
186+
workflow_id = self._webrtc_request.workflow_configuration.workflow_id
187+
if not workflow_id:
188+
if self._webrtc_request.workflow_configuration.workflow_specification:
189+
workflow_id = usage_collector._calculate_resource_hash(
190+
resource_details=self._webrtc_request.workflow_configuration.workflow_specification
191+
)
192+
else:
193+
workflow_id = "unknown"
194+
195+
# requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
196+
webrtc_plan = self._webrtc_request.requested_plan
197+
198+
usage_collector.record_usage(
199+
source=workflow_id,
200+
category="modal",
201+
api_key=self._webrtc_request.api_key,
202+
resource_details={"plan": webrtc_plan},
203+
execution_duration=(
204+
self._exec_session_stopped - self._exec_session_started
205+
).total_seconds(),
206+
)
207+
usage_collector.push_usage_payloads()
208+
140209
# Modal derives function name from class name
141210
# https://modal.com/docs/reference/modal.App#cls
142211
@app.cls(
143-
**{
144-
**decorator_kwargs,
145-
"enable_memory_snapshot": True,
146-
}
212+
**decorator_kwargs,
147213
)
148214
class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
149215
pass
150216

151217
@app.cls(
152218
**{
153219
**decorator_kwargs,
154-
"gpu": WEBRTC_MODAL_FUNCTION_GPU,
155-
"experimental_options": {"enable_gpu_snapshot": True},
220+
"enable_memory_snapshot": False,
221+
"gpu": WEBRTC_MODAL_FUNCTION_GPU, # https://modal.com/docs/guide/gpu#specifying-gpu-type
222+
"experimental_options": {
223+
"enable_gpu_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
224+
},
156225
}
157226
)
158227
class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
@@ -161,6 +230,31 @@ class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
161230
def spawn_rtc_peer_connection_modal(
162231
webrtc_request: WebRTCWorkerRequest,
163232
) -> WebRTCWorkerResult:
233+
webrtc_plans: Optional[Dict[str, WebRTCPlan]] = (
234+
usage_collector._plan_details.get_webrtc_plans(
235+
api_key=webrtc_request.api_key
236+
)
237+
)
238+
if webrtc_plans and webrtc_request.requested_plan:
239+
if webrtc_request.requested_plan not in webrtc_plans:
240+
raise RoboflowAPIUnsuccessfulRequestError(
241+
f"Unknown requested plan {webrtc_request.requested_plan}"
242+
)
243+
webrtc_request.requested_gpu = webrtc_plans[
244+
webrtc_request.requested_plan
245+
].gpu
246+
if (
247+
webrtc_plans
248+
and not webrtc_request.requested_plan
249+
and webrtc_request.requested_gpu
250+
):
251+
gpu_to_plan = {v.gpu: k for k, v in webrtc_plans.items()}
252+
if webrtc_request.requested_gpu not in gpu_to_plan:
253+
raise RoboflowAPIUnsuccessfulRequestError(
254+
f"Requested gpu {webrtc_request.requested_gpu} not associated with any plan"
255+
)
256+
webrtc_request.requested_plan = gpu_to_plan[webrtc_request.requested_gpu]
257+
164258
# https://modal.com/docs/reference/modal.Client#from_credentials
165259
client = modal.Client.from_credentials(
166260
token_id=WEBRTC_MODAL_TOKEN_ID,
@@ -186,27 +280,32 @@ def spawn_rtc_peer_connection_modal(
186280
)
187281
deployed_cls.hydrate(client=client)
188282
if webrtc_request.processing_timeout is None:
189-
logger.warning("Spawning webrtc modal function without timeout")
190-
else:
191-
logger.info(
192-
"Spawning webrtc modal function with timeout %s",
193-
webrtc_request.processing_timeout,
194-
)
283+
webrtc_request.processing_timeout = WEBRTC_MODAL_FUNCTION_MAX_TIME_LIMIT
284+
logger.warning("No timeout specified, using max timeout")
285+
logger.info(
286+
"Spawning webrtc modal function with timeout %s",
287+
webrtc_request.processing_timeout,
288+
)
195289
# https://modal.com/docs/reference/modal.Cls#with_options
196290
cls_with_options = deployed_cls.with_options(
197291
timeout=webrtc_request.processing_timeout,
198292
)
199-
if (
200-
webrtc_request.requested_gpu is not None
201-
and webrtc_request.requested_gpu != WEBRTC_MODAL_FUNCTION_GPU
202-
):
203-
logger.warning(
204-
"Spawning webrtc modal function with custom gpu %s",
293+
if webrtc_request.requested_gpu is not None:
294+
logger.info(
295+
"Spawning webrtc modal function with gpu %s",
205296
webrtc_request.requested_gpu,
206297
)
207298
cls_with_options = cls_with_options.with_options(
208299
gpu=webrtc_request.requested_gpu,
209300
)
301+
if webrtc_request.requested_region:
302+
logger.info(
303+
"Spawning webrtc modal function with region %s",
304+
webrtc_request.requested_region,
305+
)
306+
cls_with_options = cls_with_options.with_options(
307+
region=webrtc_request.requested_region,
308+
)
210309
rtc_modal_obj: RTCPeerConnectionModal = cls_with_options()
211310
# https://modal.com/docs/reference/modal.Queue#ephemeral
212311
with modal.Queue.ephemeral(client=client) as q:

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -446,12 +446,18 @@ async def _wait_ice_complete(peer_connection: RTCPeerConnectionWithLoop, timeout
446446

447447
@peer_connection.on("icegatheringstatechange")
448448
def _():
449+
logger.info(
450+
"ICE gathering state changed to %s", peer_connection.iceGatheringState
451+
)
449452
if not fut.done() and peer_connection.iceGatheringState == "complete":
450453
fut.set_result(True)
451454

452455
try:
456+
logger.info("Waiting for ICE gathering to complete...")
453457
await asyncio.wait_for(fut, timeout)
458+
logger.info("ICE gathering completed")
454459
except asyncio.TimeoutError:
460+
logger.info("ICE gathering did not complete in %s seconds", timeout)
455461
pass
456462

457463

@@ -466,10 +472,16 @@ async def init_rtc_peer_connection_with_loop(
466472
if webrtc_request.processing_timeout is not None:
467473
try:
468474
time_limit_seconds = int(webrtc_request.processing_timeout)
469-
termination_date = datetime.datetime.now() + datetime.timedelta(
475+
datetime_now = datetime.datetime.now()
476+
termination_date = datetime_now + datetime.timedelta(
470477
seconds=time_limit_seconds - 1
471478
)
472-
logger.info("Setting termination date to %s", termination_date)
479+
logger.info(
480+
"Setting termination date to %s (%s seconds from %s)",
481+
termination_date,
482+
time_limit_seconds,
483+
datetime_now,
484+
)
473485
except (TypeError, ValueError):
474486
pass
475487
if webrtc_request.stream_output is None:
@@ -578,20 +590,22 @@ async def init_rtc_peer_connection_with_loop(
578590
)
579591
return
580592

581-
if webrtc_request.webrtc_turn_config:
582-
turn_server = RTCIceServer(
583-
urls=[webrtc_request.webrtc_turn_config.urls],
584-
username=webrtc_request.webrtc_turn_config.username,
585-
credential=webrtc_request.webrtc_turn_config.credential,
586-
)
587-
peer_connection = RTCPeerConnectionWithLoop(
588-
configuration=RTCConfiguration(iceServers=[turn_server]),
589-
asyncio_loop=asyncio_loop,
590-
)
593+
if webrtc_request.webrtc_config is not None:
594+
ice_servers = []
595+
for ice_server in webrtc_request.webrtc_config.iceServers:
596+
ice_servers.append(
597+
RTCIceServer(
598+
urls=ice_server.urls,
599+
username=ice_server.username,
600+
credential=ice_server.credential,
601+
)
602+
)
591603
else:
592-
peer_connection = RTCPeerConnectionWithLoop(
593-
asyncio_loop=asyncio_loop,
594-
)
604+
ice_servers = None
605+
peer_connection = RTCPeerConnectionWithLoop(
606+
configuration=RTCConfiguration(iceServers=ice_servers) if ice_servers else None,
607+
asyncio_loop=asyncio_loop,
608+
)
595609

596610
relay = MediaRelay()
597611

inference/core/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.60.1rc2"
1+
__version__ = "0.60.1"
22

33

44
if __name__ == "__main__":

0 commit comments

Comments (0)