
Commit 1877720

Add text mode (#321)

* Pretty good spot
* Working draft
* Fix other mode
* Add js to git
* Working
* Add code
* fix
* Fix
* Add code
* Fix submit race condition
* demo
* fix
* Fix
* Fix

1 parent 1179f8e, commit 1877720

69 files changed: +110233 −22961 lines


backend/fastrtc/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,7 @@
     AdditionalOutputs,
     CloseStream,
     Warning,
+    WebRTCData,
     WebRTCError,
     aggregate_bytes_to_16bit,
     async_aggregate_bytes_to_16bit,
@@ -92,4 +93,5 @@
     "CloseStream",
     "get_current_context",
     "CartesiaTTSOptions",
+    "WebRTCData",
 ]
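
With this change, WebRTCData (imported from fastrtc.utils, as the reply_on_pause.py diff below shows) is re-exported from the package root and listed in __all__. A minimal import sketch:

    # WebRTCData is now importable from the package root alongside the other public helpers
    from fastrtc import ReplyOnPause, WebRTCData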

backend/fastrtc/reply_on_pause.py

Lines changed: 27 additions & 13 deletions

@@ -11,7 +11,7 @@

 from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
 from .tracks import EmitType, StreamHandler
-from .utils import AdditionalOutputs, create_message, split_output
+from .utils import AdditionalOutputs, WebRTCData, create_message, split_output

 logger = getLogger(__name__)

@@ -67,6 +67,14 @@ def new(self):
         [tuple[int, NDArray[np.int16]], Any],
         AsyncGenerator[EmitType, None],
     ]
+    | Callable[
+        [WebRTCData],
+        Generator[EmitType, None, None],
+    ]
+    | Callable[
+        [WebRTCData, Any],
+        AsyncGenerator[EmitType, None],
+    ]
 )


@@ -115,6 +123,7 @@ def __init__(
         output_frame_size: int | None = None,  # Deprecated
         input_sample_rate: int = 48000,
         model: PauseDetectionModel | None = None,
+        needs_args: bool = False,
     ):
         """
         Initializes the ReplyOnPause handler.
@@ -132,6 +141,7 @@ def __init__(
             output_frame_size: Deprecated.
             input_sample_rate: The expected sample rate of incoming audio.
             model: An optional pre-initialized VAD model instance.
+            needs_args: Whether the reply function expects additional arguments.
         """
         super().__init__(
             expected_layout,
@@ -152,11 +162,12 @@ def __init__(
         self.model_options = model_options
         self.algo_options = algo_options or AlgoOptions()
         self.startup_fn = startup_fn
+        self.needs_args = needs_args

     @property
     def _needs_additional_inputs(self) -> bool:
         """Checks if the reply function `fn` expects additional arguments."""
-        return len(inspect.signature(self.fn).parameters) > 1
+        return len(inspect.signature(self.fn).parameters) > 1 or self.needs_args

     def start_up(self):
         """
@@ -187,6 +198,7 @@ def copy(self):
             self.output_frame_size,
             self.input_sample_rate,
             self.model,
+            self.needs_args,
         )

     def determine_pause(
@@ -361,19 +373,21 @@ def emit(self):
         else:
             if not self.generator:
                 self.send_message_sync(create_message("log", "pause_detected"))
-                if self._needs_additional_inputs and not self.args_set.is_set():
-                    if not self.phone_mode:
-                        self.wait_for_args_sync()
-                    else:
-                        self.latest_args = [None]
-                        self.args_set.set()
+                if self._needs_additional_inputs and not self.phone_mode:
+                    self.wait_for_args_sync()
+                else:
+                    self.latest_args = [None]
+                    self.args_set.set()
                 logger.debug("Creating generator")
-                audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
-                if self._needs_additional_inputs:
-                    self.latest_args[0] = (self.state.sampling_rate, audio)
-                    self.generator = self.fn(*self.latest_args)  # type: ignore
+                if self.state.stream is not None and self.state.stream.size > 0:
+                    audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
                 else:
-                    self.generator = self.fn((self.state.sampling_rate, audio))  # type: ignore
+                    audio = np.array([[]], dtype=np.int16)
+                if isinstance(self.latest_args[0], WebRTCData):
+                    self.latest_args[0].audio = (self.state.sampling_rate, audio)
+                else:
+                    self.latest_args[0] = (self.state.sampling_rate, audio)
+                self.generator = self.fn(*self.latest_args)  # type: ignore
                 logger.debug("Latest args: %s", self.latest_args)
                 self.state = self.state.new()
                 self.state.responding = True
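
Taken together, these changes let a ReplyOnPause reply function receive a WebRTCData object instead of a bare (sampling_rate, audio) tuple: needs_args forces the handler to wait for frontend-submitted arguments even when the function has a single parameter, emit() now tolerates an empty audio buffer (a text-only submission), and it fills latest_args[0].audio when the first argument is a WebRTCData. A minimal sketch of a reply function for the textbox variant; WebRTCData attribute names other than audio are assumptions, not shown in this diff:

    import numpy as np
    from fastrtc import ReplyOnPause, WebRTCData

    def reply(data: WebRTCData):
        # emit() sets data.audio to a (sampling_rate, ndarray) pair; for a
        # text-only submission the array is empty instead of raising on reshape().
        sampling_rate, audio = data.audio
        if audio.size == 0:
            audio = np.zeros((1, sampling_rate), dtype=np.int16)  # one second of silence as a placeholder
        yield (sampling_rate, audio)  # echo the captured (or placeholder) audio back

    # needs_args=True marks the single-parameter reply as still expecting
    # frontend-submitted arguments; Stream sets this flag automatically for the
    # textbox variant (see the stream.py diff below).
    handler = ReplyOnPause(reply, needs_args=True)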

backend/fastrtc/reply_on_stopwords.py

Lines changed: 4 additions & 0 deletions

@@ -61,6 +61,7 @@ def __init__(
         output_frame_size: int | None = None,  # Deprecated
         input_sample_rate: int = 48000,
         model: PauseDetectionModel | None = None,
+        needs_args: bool = False,
     ):
         """
         Initializes the ReplyOnStopWords handler.
@@ -80,6 +81,7 @@ def __init__(
             output_frame_size: Deprecated.
             input_sample_rate: The expected sample rate of incoming audio.
             model: An optional pre-initialized VAD model instance.
+            needs_args: Whether the reply function expects additional arguments.
         """
         super().__init__(
             fn,
@@ -92,6 +94,7 @@ def __init__(
             output_frame_size=output_frame_size,
             input_sample_rate=input_sample_rate,
             model=model,
+            needs_args=needs_args,
         )
         self.stop_words = stop_words
         self.state = ReplyOnStopWordsState()
@@ -236,4 +239,5 @@ def copy(self):
             self.output_frame_size,
             self.input_sample_rate,
             self.model,
+            self.needs_args,
         )
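
ReplyOnStopWords only threads the new flag through: it accepts needs_args, forwards it to ReplyOnPause.__init__, and preserves it in copy(). A usage sketch; keyword names other than stop_words and needs_args, and the example phrase, are assumptions:

    from fastrtc import ReplyOnStopWords

    # Same WebRTCData-aware reply function as above, now gated on a wake phrase.
    handler = ReplyOnStopWords(
        reply,
        stop_words=["computer"],  # assumed example phrase
        needs_args=True,          # forwarded unchanged to ReplyOnPause
    )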

backend/fastrtc/stream.py

Lines changed: 139 additions & 56 deletions

@@ -141,6 +141,15 @@ def __init__(
         self.modality = modality
         self.rtp_params = rtp_params
         self.event_handler = handler
+        if (
+            ui_args
+            and ui_args.get("variant") == "textbox"
+            and hasattr(handler, "needs_args")
+        ):
+            self.event_handler.needs_args = True  # type: ignore
+        else:
+            self.event_handler.needs_args = False  # type: ignore
+
         self.concurrency_limit = cast(
             (int),
             1 if concurrency_limit in ["default", None] else concurrency_limit,
@@ -574,28 +583,58 @@ def _generate_default_ui(
             </div>
             """
         )
-        with gr.Row():
-            with gr.Column():
-                with gr.Group():
-                    image = WebRTC(
-                        label="Stream",
-                        rtc_configuration=self.rtc_configuration,
-                        track_constraints=self.track_constraints,
-                        mode="send",
-                        modality="audio",
-                        icon=ui_args.get("icon"),
-                        icon_button_color=ui_args.get("icon_button_color"),
-                        pulse_color=ui_args.get("pulse_color"),
-                        icon_radius=ui_args.get("icon_radius"),
-                    )
-                self.webrtc_component = image
-                for component in additional_input_components:
-                    if component not in same_components:
+        if ui_args.get("variant", "textbox"):
+            with gr.Row():
+                if additional_input_components:
+                    with gr.Column():
+                        for component in additional_input_components:
                             component.render()
-                if additional_output_components:
+                diff_output_components = [
+                    component
+                    for component in additional_output_components
+                    if component not in same_components
+                ]
+                if diff_output_components:
+                    with gr.Column():
+                        for component in diff_output_components:
+                            component.render()
+            with gr.Row():
+                image = WebRTC(
+                    label="Stream",
+                    rtc_configuration=self.rtc_configuration,
+                    track_constraints=self.track_constraints,
+                    mode="send",
+                    modality="audio",
+                    icon=ui_args.get("icon"),
+                    icon_button_color=ui_args.get("icon_button_color"),
+                    pulse_color=ui_args.get("pulse_color"),
+                    icon_radius=ui_args.get("icon_radius"),
+                    variant=ui_args.get("variant", "wave"),
+                )
+        else:
+            with gr.Row():
                 with gr.Column():
-                    for component in additional_output_components:
-                        component.render()
+                    with gr.Group():
+                        image = WebRTC(
+                            label="Stream",
+                            rtc_configuration=self.rtc_configuration,
+                            track_constraints=self.track_constraints,
+                            mode="send",
+                            modality="audio",
+                            icon=ui_args.get("icon"),
+                            icon_button_color=ui_args.get("icon_button_color"),
+                            pulse_color=ui_args.get("pulse_color"),
+                            icon_radius=ui_args.get("icon_radius"),
+                            variant=ui_args.get("variant", "wave"),
+                        )
+                    for component in additional_input_components:
+                        if component not in same_components:
+                            component.render()
+                    if additional_output_components:
+                        with gr.Column():
+                            for component in additional_output_components:
+                                component.render()
+        self.webrtc_component = image
         image.stream(
             fn=self.event_handler,
             inputs=[image] + additional_input_components,
@@ -630,45 +669,89 @@ def _generate_default_ui(
             </div>
             """
         )
-        with gr.Row():
-            with gr.Column():
-                with gr.Group():
-                    image = WebRTC(
-                        label="Stream",
-                        rtc_configuration=self.rtc_configuration,
-                        track_constraints=self.track_constraints,
-                        mode="send-receive",
-                        modality="audio",
-                        icon=ui_args.get("icon"),
-                        icon_button_color=ui_args.get("icon_button_color"),
-                        pulse_color=ui_args.get("pulse_color"),
-                        icon_radius=ui_args.get("icon_radius"),
-                    )
-                self.webrtc_component = image
-                for component in additional_input_components:
-                    if component not in same_components:
+        if ui_args.get("variant", "") == "textbox":
+            with gr.Row():
+                if additional_input_components:
+                    with gr.Column():
+                        for component in additional_input_components:
                             component.render()
+                diff_output_components = [
+                    component
+                    for component in additional_output_components
+                    if component not in same_components
+                ]
+                if diff_output_components:
+                    with gr.Column():
+                        for component in diff_output_components:
+                            component.render()
+            with gr.Row():
+                image = WebRTC(
+                    label="Stream",
+                    rtc_configuration=self.rtc_configuration,
+                    track_constraints=self.track_constraints,
+                    mode="send-receive",
+                    modality="audio",
+                    icon=ui_args.get("icon"),
+                    icon_button_color=ui_args.get("icon_button_color"),
+                    pulse_color=ui_args.get("pulse_color"),
+                    icon_radius=ui_args.get("icon_radius"),
+                    variant=ui_args.get("variant", "wave"),
+                )
+        else:
             if additional_output_components:
-                with gr.Column():
-                    for component in additional_output_components:
-                        component.render()
-
-        image.stream(
-            fn=self.event_handler,
-            inputs=[image] + additional_input_components,
-            outputs=[image],
-            time_limit=self.time_limit,
-            concurrency_limit=self.concurrency_limit,  # type: ignore
-            send_input_on=ui_args.get("send_input_on", "change"),
+                with gr.Row():
+                    with gr.Column():
+                        image = WebRTC(
+                            label="Stream",
+                            rtc_configuration=self.rtc_configuration,
+                            track_constraints=self.track_constraints,
+                            mode="send-receive",
+                            modality="audio",
+                            icon=ui_args.get("icon"),
+                            icon_button_color=ui_args.get("icon_button_color"),
+                            pulse_color=ui_args.get("pulse_color"),
+                            icon_radius=ui_args.get("icon_radius"),
+                        )
+                        for component in additional_input_components:
+                            if component not in same_components:
+                                component.render()
+                    with gr.Column():
+                        for component in additional_output_components:
+                            component.render()
+            else:
+                with gr.Row():
+                    with gr.Column():
+                        image = WebRTC(
+                            label="Stream",
+                            rtc_configuration=self.rtc_configuration,
+                            track_constraints=self.track_constraints,
+                            mode="send-receive",
+                            modality="audio",
+                            icon=ui_args.get("icon"),
+                            icon_button_color=ui_args.get("icon_button_color"),
+                            pulse_color=ui_args.get("pulse_color"),
+                            icon_radius=ui_args.get("icon_radius"),
+                        )
+                        for component in additional_input_components:
+                            if component not in same_components:
+                                component.render()
+        self.webrtc_component = image
+        image.stream(
+            fn=self.event_handler,
+            inputs=[image] + additional_input_components,
+            outputs=[image],
+            time_limit=self.time_limit,
+            concurrency_limit=self.concurrency_limit,  # type: ignore
+            send_input_on=ui_args.get("send_input_on", "change"),
+        )
+        if additional_output_components:
+            assert self.additional_outputs_handler
+            image.on_additional_outputs(
+                self.additional_outputs_handler,
+                inputs=additional_output_components,
+                outputs=additional_output_components,
+                concurrency_limit=self.concurrency_limit_gradio,  # type: ignore
             )
-        if additional_output_components:
-            assert self.additional_outputs_handler
-            image.on_additional_outputs(
-                self.additional_outputs_handler,
-                inputs=additional_output_components,
-                outputs=additional_output_components,
-                concurrency_limit=self.concurrency_limit_gradio,  # type: ignore
-            )
     elif self.modality == "audio-video" and self.mode == "send-receive":
         css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
         .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""