4 changes: 4 additions & 0 deletions backend/fastrtc/utils.py
@@ -18,8 +18,10 @@
import numpy as np
from fastapi import WebSocket
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.oauth import OAuthToken
from numpy.typing import NDArray
from pydub import AudioSegment
from starlette.requests import Request

logger = logging.getLogger(__name__)

@@ -81,6 +83,8 @@ def create_message(
class Context:
webrtc_id: str
websocket: WebSocket | None = None
request: Request | None = None
oauth_token: OAuthToken | None = None


current_context: ContextVar[Context | None] = ContextVar(
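The two new `Context` fields surface the caller's Starlette request and Hugging Face OAuth token to stream handlers through `get_current_context()`, which the updated demo below relies on. A minimal, hypothetical handler sketch (the handler name and sign-in message are illustrative, not part of this PR):

```python
from fastrtc import WebRTCError, get_current_context


def handler(audio):
    # Both fields default to None when the offer carried no authenticated session.
    context = get_current_context()
    if context.oauth_token is None:
        raise WebRTCError("Please sign in to use this demo.")
    # context.request (a Starlette Request) and context.oauth_token are now
    # available for provider calls, logging, per-user limits, etc.
    ...
```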
33 changes: 30 additions & 3 deletions backend/fastrtc/webrtc.py
@@ -14,7 +14,7 @@
cast,
)

from gradio import wasm_utils
from gradio import Request, wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file

@@ -403,9 +403,36 @@ async def turn(self, _):
return {"error": str(e)}

@server
async def offer(self, body):
async def offer(
self,
body,
request: Request | None = None,
):
from gradio import oauth

oauth_token = None

if request is not None:
try:
session = request.session
oauth_info = session.get("oauth_info", None)
oauth_token = (
oauth.OAuthToken(
token=oauth_info["access_token"],
scope=oauth_info["scope"],
expires_at=oauth_info["expires_at"],
)
if oauth_info is not None
else None
)
except Exception:
import traceback

traceback.print_exc()
oauth_token = None

return await self.handle_offer(
body, self.set_additional_outputs(body["webrtc_id"])
body, self.set_additional_outputs(body["webrtc_id"]), request, oauth_token
)

@server
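`offer` reads only three keys from the session's `oauth_info` entry and falls back to `oauth_token = None` on any failure, so unauthenticated WebRTC offers keep working. A hedged sketch of the session payload it assumes (placeholder values; the keys are exactly the ones read above):

```python
# Hypothetical contents of request.session["oauth_info"] after Gradio's OAuth login.
# Only the three keys consumed by offer() are assumed; the values are placeholders.
oauth_info = {
    "access_token": "hf_xxx",     # becomes OAuthToken.token
    "scope": "openid profile",    # becomes OAuthToken.scope
    "expires_at": 1700000000.0,   # becomes OAuthToken.expires_at
}
```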
10 changes: 7 additions & 3 deletions backend/fastrtc/webrtc_connection_mixin.py
@@ -221,7 +221,7 @@ async def _trigger_response(self, webrtc_id: str, args: list[Any] | None = None)
else:
return {"status": "failed", "meta": {"error": "not_a_reply_on_pause"}}

async def handle_offer(self, body, set_outputs):
async def handle_offer(self, body, set_outputs, request=None, oauth_token=None):
logger.debug("Starting to handle offer")
logger.debug("Offer body %s", body)

@@ -372,7 +372,9 @@ async def _():
def _(track):
relay = MediaRelay()
handler = self.handlers[body["webrtc_id"]]
context = Context(webrtc_id=body["webrtc_id"])
context = Context(
webrtc_id=body["webrtc_id"], oauth_token=oauth_token, request=request
)
if self.modality == "video" and track.kind == "video":
args = {}
handler_ = handler
@@ -426,7 +428,9 @@ def _(track):
elif self.mode == "send":
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())

context = Context(webrtc_id=body["webrtc_id"])
context = Context(
webrtc_id=body["webrtc_id"], oauth_token=oauth_token, request=request
)
if self.mode == "receive":
if self.modality == "video":
if isinstance(self.event_handler, VideoStreamHandler):
17 changes: 7 additions & 10 deletions demo/integrated_textbox/app.py
@@ -14,6 +14,7 @@
WebRTC,
WebRTCData,
WebRTCError,
get_current_context,
get_hf_turn_credentials,
get_stt_model,
)
@@ -28,13 +29,13 @@
def response(
data: WebRTCData,
conversation: list[dict],
token: str | None = None,
model: str = "meta-llama/Llama-3.2-3B-Instruct",
provider: str = "sambanova",
):
context = get_current_context()
print("conversation before", conversation)
if not provider.startswith("http") and not token:
raise WebRTCError("Please add your HF token.")
if not provider.startswith("http") and not context.oauth_token:
raise WebRTCError("Please Sign in to use this demo.")

if data.audio is not None and data.audio[1].size > 0:
user_audio_text = stt_model.stt(data.audio)
@@ -48,7 +49,7 @@ def response(
client = OpenAI(base_url=provider, api_key="ollama")
else:
client = huggingface_hub.InferenceClient(
api_key=token,
api_key=context.oauth_token.access_token, # type: ignore
provider=provider, # type: ignore
)

@@ -103,9 +104,6 @@ def hide_token(provider: str):
"""
)
with gr.Sidebar():
token = gr.Textbox(
placeholder="Place your HF token here", type="password", label="HF Token"
)
model = gr.Dropdown(
choices=["meta-llama/Llama-3.2-3B-Instruct"],
allow_custom_value=True,
@@ -114,11 +112,10 @@ def hide_token(provider: str):
provider = gr.Dropdown(
label="Provider",
choices=providers,
value="sambanova",
value="auto",
info="Select a hf-compatible provider or type the url of your server, e.g. http://127.0.0.1:11434/v1 for ollama",
allow_custom_value=True,
)
provider.change(hide_token, inputs=[provider], outputs=[token])
cb = gr.Chatbot(type="messages", height=600)
webrtc = WebRTC(
modality="audio",
@@ -131,7 +128,7 @@
)
webrtc.stream(
ReplyOnPause(response), # type: ignore
inputs=[webrtc, cb, token, model, provider],
inputs=[webrtc, cb, model, provider],
outputs=[cb],
concurrency_limit=100,
)
3 changes: 2 additions & 1 deletion frontend/shared/InteractiveAudio.svelte
@@ -345,7 +345,8 @@
{/if}
<div
class="audio-container"
class:full-screen={full_screen || full_screen === null}
class:full-screen={(full_screen || full_screen === null) &&
variant !== "textbox"}
>
<audio
class="standard-player"