@@ -141,6 +141,15 @@ def __init__(
141141 self .modality = modality
142142 self .rtp_params = rtp_params
143143 self .event_handler = handler
144+ if (
145+ ui_args
146+ and ui_args .get ("variant" ) == "textbox"
147+ and hasattr (handler , "needs_args" )
148+ ):
149+ self .event_handler .needs_args = True # type: ignore
150+ else :
151+ self .event_handler .needs_args = False # type: ignore
152+
144153 self .concurrency_limit = cast (
145154 (int ),
146155 1 if concurrency_limit in ["default" , None ] else concurrency_limit ,
@@ -574,28 +583,58 @@ def _generate_default_ui(
574583 </div>
575584 """
576585 )
577- with gr .Row ():
578- with gr .Column ():
579- with gr .Group ():
580- image = WebRTC (
581- label = "Stream" ,
582- rtc_configuration = self .rtc_configuration ,
583- track_constraints = self .track_constraints ,
584- mode = "send" ,
585- modality = "audio" ,
586- icon = ui_args .get ("icon" ),
587- icon_button_color = ui_args .get ("icon_button_color" ),
588- pulse_color = ui_args .get ("pulse_color" ),
589- icon_radius = ui_args .get ("icon_radius" ),
590- )
591- self .webrtc_component = image
592- for component in additional_input_components :
593- if component not in same_components :
586+ if ui_args .get ("variant" , "textbox" ):
587+ with gr .Row ():
588+ if additional_input_components :
589+ with gr .Column ():
590+ for component in additional_input_components :
594591 component .render ()
595- if additional_output_components :
592+ diff_output_components = [
593+ component
594+ for component in additional_output_components
595+ if component not in same_components
596+ ]
597+ if diff_output_components :
598+ with gr .Column ():
599+ for component in diff_output_components :
600+ component .render ()
601+ with gr .Row ():
602+ image = WebRTC (
603+ label = "Stream" ,
604+ rtc_configuration = self .rtc_configuration ,
605+ track_constraints = self .track_constraints ,
606+ mode = "send" ,
607+ modality = "audio" ,
608+ icon = ui_args .get ("icon" ),
609+ icon_button_color = ui_args .get ("icon_button_color" ),
610+ pulse_color = ui_args .get ("pulse_color" ),
611+ icon_radius = ui_args .get ("icon_radius" ),
612+ variant = ui_args .get ("variant" , "wave" ),
613+ )
614+ else :
615+ with gr .Row ():
596616 with gr .Column ():
597- for component in additional_output_components :
598- component .render ()
617+ with gr .Group ():
618+ image = WebRTC (
619+ label = "Stream" ,
620+ rtc_configuration = self .rtc_configuration ,
621+ track_constraints = self .track_constraints ,
622+ mode = "send" ,
623+ modality = "audio" ,
624+ icon = ui_args .get ("icon" ),
625+ icon_button_color = ui_args .get ("icon_button_color" ),
626+ pulse_color = ui_args .get ("pulse_color" ),
627+ icon_radius = ui_args .get ("icon_radius" ),
628+ variant = ui_args .get ("variant" , "wave" ),
629+ )
630+ for component in additional_input_components :
631+ if component not in same_components :
632+ component .render ()
633+ if additional_output_components :
634+ with gr .Column ():
635+ for component in additional_output_components :
636+ component .render ()
637+ self .webrtc_component = image
599638 image .stream (
600639 fn = self .event_handler ,
601640 inputs = [image ] + additional_input_components ,
@@ -630,45 +669,89 @@ def _generate_default_ui(
630669 </div>
631670 """
632671 )
633- with gr .Row ():
634- with gr .Column ():
635- with gr .Group ():
636- image = WebRTC (
637- label = "Stream" ,
638- rtc_configuration = self .rtc_configuration ,
639- track_constraints = self .track_constraints ,
640- mode = "send-receive" ,
641- modality = "audio" ,
642- icon = ui_args .get ("icon" ),
643- icon_button_color = ui_args .get ("icon_button_color" ),
644- pulse_color = ui_args .get ("pulse_color" ),
645- icon_radius = ui_args .get ("icon_radius" ),
646- )
647- self .webrtc_component = image
648- for component in additional_input_components :
649- if component not in same_components :
672+ if ui_args .get ("variant" , "" ) == "textbox" :
673+ with gr .Row ():
674+ if additional_input_components :
675+ with gr .Column ():
676+ for component in additional_input_components :
650677 component .render ()
678+ diff_output_components = [
679+ component
680+ for component in additional_output_components
681+ if component not in same_components
682+ ]
683+ if diff_output_components :
684+ with gr .Column ():
685+ for component in diff_output_components :
686+ component .render ()
687+ with gr .Row ():
688+ image = WebRTC (
689+ label = "Stream" ,
690+ rtc_configuration = self .rtc_configuration ,
691+ track_constraints = self .track_constraints ,
692+ mode = "send-receive" ,
693+ modality = "audio" ,
694+ icon = ui_args .get ("icon" ),
695+ icon_button_color = ui_args .get ("icon_button_color" ),
696+ pulse_color = ui_args .get ("pulse_color" ),
697+ icon_radius = ui_args .get ("icon_radius" ),
698+ variant = ui_args .get ("variant" , "wave" ),
699+ )
700+ else :
651701 if additional_output_components :
652- with gr .Column ():
653- for component in additional_output_components :
654- component .render ()
655-
656- image .stream (
657- fn = self .event_handler ,
658- inputs = [image ] + additional_input_components ,
659- outputs = [image ],
660- time_limit = self .time_limit ,
661- concurrency_limit = self .concurrency_limit , # type: ignore
662- send_input_on = ui_args .get ("send_input_on" , "change" ),
702+ with gr .Row ():
703+ with gr .Column ():
704+ image = WebRTC (
705+ label = "Stream" ,
706+ rtc_configuration = self .rtc_configuration ,
707+ track_constraints = self .track_constraints ,
708+ mode = "send-receive" ,
709+ modality = "audio" ,
710+ icon = ui_args .get ("icon" ),
711+ icon_button_color = ui_args .get ("icon_button_color" ),
712+ pulse_color = ui_args .get ("pulse_color" ),
713+ icon_radius = ui_args .get ("icon_radius" ),
714+ )
715+ for component in additional_input_components :
716+ if component not in same_components :
717+ component .render ()
718+ with gr .Column ():
719+ for component in additional_output_components :
720+ component .render ()
721+ else :
722+ with gr .Row ():
723+ with gr .Column ():
724+ image = WebRTC (
725+ label = "Stream" ,
726+ rtc_configuration = self .rtc_configuration ,
727+ track_constraints = self .track_constraints ,
728+ mode = "send-receive" ,
729+ modality = "audio" ,
730+ icon = ui_args .get ("icon" ),
731+ icon_button_color = ui_args .get ("icon_button_color" ),
732+ pulse_color = ui_args .get ("pulse_color" ),
733+ icon_radius = ui_args .get ("icon_radius" ),
734+ )
735+ for component in additional_input_components :
736+ if component not in same_components :
737+ component .render ()
738+ self .webrtc_component = image
739+ image .stream (
740+ fn = self .event_handler ,
741+ inputs = [image ] + additional_input_components ,
742+ outputs = [image ],
743+ time_limit = self .time_limit ,
744+ concurrency_limit = self .concurrency_limit , # type: ignore
745+ send_input_on = ui_args .get ("send_input_on" , "change" ),
746+ )
747+ if additional_output_components :
748+ assert self .additional_outputs_handler
749+ image .on_additional_outputs (
750+ self .additional_outputs_handler ,
751+ inputs = additional_output_components ,
752+ outputs = additional_output_components ,
753+ concurrency_limit = self .concurrency_limit_gradio , # type: ignore
663754 )
664- if additional_output_components :
665- assert self .additional_outputs_handler
666- image .on_additional_outputs (
667- self .additional_outputs_handler ,
668- inputs = additional_output_components ,
669- outputs = additional_output_components ,
670- concurrency_limit = self .concurrency_limit_gradio , # type: ignore
671- )
672755 elif self .modality == "audio-video" and self .mode == "send-receive" :
673756 css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
674757 .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
0 commit comments