Sofia Casadei committed
Commit 5ef2360 · 1 Parent(s): aec5df4

first version
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,50 @@
+ # Stage 1: Get uv installer
+ FROM ghcr.io/astral-sh/uv:0.2.12 AS uv
+
+ # Stage 2: Main application image
+ FROM python:3.10.12-slim-bookworm
+
+ # Copy uv from first stage
+ COPY --from=uv /uv /uv
+
+ # Create virtual environment with uv
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     /uv venv /opt/venv
+
+ # Set environment variables
+ ENV VIRTUAL_ENV=/opt/venv \
+     PATH="/opt/venv/bin:$PATH"
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     portaudio19-dev \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create user and set permissions (required for HF Spaces)
+ RUN useradd -m -u 1000 user && \
+     chown -R user /opt/venv
+
+ # Switch to user context
+ USER user
+ WORKDIR /app
+
+ # Set home to user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     HF_HOME=/home/user/.cache/huggingface
+
+ # Copy requirements first for caching
+ COPY --chown=user requirements.txt .
+
+ # Install Python packages with uv caching
+ RUN --mount=type=cache,target=/home/user/.cache/uv \
+     /uv pip install -r requirements.txt
+
+ # Copy application code
+ COPY --chown=user . .
+
+ # Expose FastRTC port (matches HF Spaces default)
+ EXPOSE 7860
+
+ # Start the application with uvicorn (the FastAPI app lives in app.py)
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -4,6 +4,7 @@ emoji: 🐢
  colorFrom: indigo
  colorTo: gray
  sdk: docker
+ app_port: 7860
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import logging
+
+ import gradio as gr
+ import numpy as np
+ from dotenv import load_dotenv
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse, HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastrtc import (
+     AdditionalOutputs,
+     ReplyOnPause,
+     Stream,
+     AlgoOptions,
+     SileroVadOptions,
+     audio_to_bytes,
+ )
+ from transformers import (
+     AutoModelForSpeechSeq2Seq,
+     AutoProcessor,
+     pipeline,
+ )
+ from transformers.utils import is_flash_attn_2_available
+
+ from utils.logger_config import setup_logging
+ from utils.device import get_device, get_torch_and_np_dtypes
+ from utils.turn_server import get_rtc_credentials
+
+
+ load_dotenv()
+ setup_logging(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+
+ device = get_device(force_cpu=False)
+ torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
+ logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
+
+
+ attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
+ logger.info(f"Using attention: {attention}")
+
+
+ model_id = "openai/whisper-large-v3-turbo"
+ logger.info(f"Loading Whisper model: {model_id}")
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     model_id,
+     torch_dtype=torch_dtype,
+     low_cpu_mem_usage=True,
+     use_safetensors=True,
+     attn_implementation=attention
+ )
+ model.to(device)
+
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ transcribe_pipeline = pipeline(
+     task="automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     torch_dtype=torch_dtype,
+     device=device,
+ )
+
+ # Warm up the model with empty audio
+ logger.info("Warming up Whisper model with dummy input")
+ warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence
+ transcribe_pipeline(warmup_audio)
+ logger.info("Model warmup complete")
+
+
+ async def transcribe(audio: tuple[int, np.ndarray]):
+     sample_rate, audio_array = audio
+     logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
+
+     outputs = transcribe_pipeline(
+         audio_to_bytes(audio),
+         chunk_length_s=3,
+         batch_size=1,
+         generate_kwargs={
+             'task': 'transcribe',
+             'language': 'english',
+         },
+         # return_timestamps="word"
+     )
+     yield AdditionalOutputs(outputs["text"].strip())
+
+
+ logger.info("Initializing FastRTC stream")
+ stream = Stream(
+     handler=ReplyOnPause(
+         transcribe,
+         algo_options=AlgoOptions(
+             # Duration in seconds of audio chunks (default 0.6)
+             audio_chunk_duration=0.6,
+             # If a chunk has more than this many seconds of speech, the user has started talking (default 0.2)
+             started_talking_threshold=0.2,
+             # If, after the user started speaking, a chunk has less than this many seconds of speech, the user has stopped speaking (default 0.1)
+             speech_threshold=0.1,
+         ),
+         model_options=SileroVadOptions(
+             # Threshold for what is considered speech (default 0.5)
+             threshold=0.5,
+             # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
+             min_speech_duration_ms=250,
+             # Max duration of speech chunks; longer chunks are split (default float('inf'))
+             max_speech_duration_s=30,
+             # Wait this many ms at the end of each speech chunk before separating it (default 2000)
+             min_silence_duration_ms=2000,
+             # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16k sample rate (default 1024)
+             window_size_samples=1024,
+             # Final speech chunks are padded by speech_pad_ms on each side (default 400)
+             speech_pad_ms=400,
+         ),
+     ),
+     # send-receive: bidirectional streaming (default)
+     # send: client to server only
+     # receive: server to client only
+     modality="audio",
+     mode="send",
+     additional_outputs=[
+         gr.Textbox(label="Transcript"),
+     ],
+     additional_outputs_handler=lambda current, new: current + " " + new,
+     rtc_configuration=get_rtc_credentials(provider="hf") if os.getenv("APP_MODE") == "deployed" else None
+ )
+
+ app = FastAPI()
+ stream.mount(app)
+
+ # Serve static assets; index.html loads /static/client.js
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ @app.get("/transcript")
+ def _(webrtc_id: str):
+     logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
+     async def output_stream():
+         try:
+             async for output in stream.output_stream(webrtc_id):
+                 transcript = output.args[0]
+                 logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
+                 yield f"event: output\ndata: {transcript}\n\n"
+         except Exception as e:
+             logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
+             raise
+
+     return StreamingResponse(output_stream(), media_type="text/event-stream")
+
+ @app.get("/")
+ async def root():
+     return HTMLResponse(content=open("static/index.html").read())
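
For quick testing outside the browser, the /transcript endpoint can also be read from plain Python. A minimal sketch, assuming the server is running locally on port 7860 and that WEBRTC_ID (a hypothetical placeholder) matches a session already created via /webrtc/offer:

import requests

WEBRTC_ID = "abc123"  # hypothetical: must match a live WebRTC session

# Stream server-sent events from the transcript endpoint
with requests.get(
    f"http://localhost:7860/transcript?webrtc_id={WEBRTC_ID}",
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # SSE payload lines look like "data: <transcript chunk>"
        if line.startswith("data:"):
            print(line[len("data:"):].strip())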
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ accelerate==1.4.0
+ fastrtc==0.0.10
+ fastrtc[vad]
+ python-dotenv==1.0.1
+ transformers==4.49.0
+ torch==2.6.0
+ torchaudio==2.6.0
+ fastapi
+ uvicorn[standard]
static/client.js ADDED
@@ -0,0 +1,127 @@
+ // Global variables
+ let peerConnection = null;
+ let dataChannel = null;
+ let webrtcId = null;
+ let eventSource = null;
+
+ // Helper function to generate unique ID
+ function generateUniqueId() {
+     return Math.random().toString(36).substring(7);
+ }
+
+ // Update UI status
+ function updateStatus(connected) {
+     const statusDiv = document.getElementById('status');
+     const connectBtn = document.getElementById('connectBtn');
+     const disconnectBtn = document.getElementById('disconnectBtn');
+
+     statusDiv.textContent = connected ? 'Connected' : 'Disconnected';
+     statusDiv.className = connected ? 'connected' : 'disconnected';
+     connectBtn.disabled = connected;
+     disconnectBtn.disabled = !connected;
+ }
+
+ // Setup WebRTC connection
+ async function setupWebRTC() {
+     try {
+         // Create peer connection
+         peerConnection = new RTCPeerConnection();
+         webrtcId = generateUniqueId();
+
+         // Get audio stream from microphone
+         const stream = await navigator.mediaDevices.getUserMedia({
+             audio: true
+         });
+
+         // Add audio stream to peer connection
+         stream.getTracks().forEach(track => {
+             peerConnection.addTrack(track, stream);
+         });
+
+         // Create data channel
+         dataChannel = peerConnection.createDataChannel("text");
+
+         // Handle data channel messages
+         dataChannel.onmessage = (event) => {
+             const message = JSON.parse(event.data);
+             console.log("Received message:", message);
+
+             // Handle different message types
+             switch (message.type) {
+                 case 'log':
+                     console.log("Server log:", message.data);
+                     break;
+                 case 'error':
+                     console.error("Server error:", message.data);
+                     break;
+                 case 'warning':
+                     console.warn("Server warning:", message.data);
+                     break;
+             }
+         };
+
+         // Create and send offer
+         const offer = await peerConnection.createOffer();
+         await peerConnection.setLocalDescription(offer);
+
+         // Send offer to server
+         const response = await fetch('/webrtc/offer', {
+             method: 'POST',
+             headers: { 'Content-Type': 'application/json' },
+             body: JSON.stringify({
+                 sdp: offer.sdp,
+                 type: offer.type,
+                 webrtc_id: webrtcId
+             })
+         });
+
+         if (!response.ok) {
+             throw new Error(`HTTP error! status: ${response.status}`);
+         }
+
+         // Handle server response
+         const serverResponse = await response.json();
+
+         // Check for error response
+         if (serverResponse.status === 'failed') {
+             throw new Error(serverResponse.meta.error);
+         }
+
+         // Set remote description
+         await peerConnection.setRemoteDescription(serverResponse);
+
+         // Update UI
+         updateStatus(true);
+
+         // Subscribe to the transcript SSE stream for this session
+         eventSource = new EventSource(`/transcript?webrtc_id=${webrtcId}`);
+
+         // The server sends named "output" events, so listen for those
+         // (onmessage only fires for unnamed "message" events)
+         eventSource.addEventListener('output', (event) => {
+             const transcriptDiv = document.getElementById('transcript');
+             transcriptDiv.innerHTML += `<p>${event.data}</p>`;
+         });
+
+     } catch (error) {
+         console.error("Error setting up WebRTC:", error);
+         updateStatus(false);
+     }
+ }
+
+ // Cleanup function
+ function disconnect() {
+     if (peerConnection) {
+         peerConnection.close();
+         peerConnection = null;
+     }
+     if (dataChannel) {
+         dataChannel.close();
+         dataChannel = null;
+     }
+     if (eventSource) {
+         eventSource.close();
+         eventSource = null;
+     }
+     webrtcId = null;
+     updateStatus(false);
+ }
+
+ // Add event listeners when page loads
+ document.addEventListener('DOMContentLoaded', () => {
+     document.getElementById('connectBtn').addEventListener('click', setupWebRTC);
+     document.getElementById('disconnectBtn').addEventListener('click', disconnect);
+ });
static/index.html ADDED
@@ -0,0 +1,53 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>FastRTC Audio Client</title>
+     <style>
+         body {
+             font-family: Arial, sans-serif;
+             max-width: 800px;
+             margin: 0 auto;
+             padding: 20px;
+         }
+         .controls {
+             margin: 20px 0;
+         }
+         button {
+             padding: 10px 20px;
+             margin: 5px;
+         }
+         #status {
+             margin: 10px 0;
+             padding: 10px;
+             border-radius: 4px;
+         }
+         .connected {
+             background-color: #d4edda;
+             color: #155724;
+         }
+         .disconnected {
+             background-color: #f8d7da;
+             color: #721c24;
+         }
+     </style>
+ </head>
+ <body>
+     <h1>FastRTC Audio Client</h1>
+     <div id="status" class="disconnected">Disconnected</div>
+
+     <div class="controls">
+         <button id="connectBtn">Connect</button>
+         <button id="disconnectBtn" disabled>Disconnect</button>
+     </div>
+
+     <!-- Audio element for playback -->
+     <audio id="audioOutput" autoplay></audio>
+
+     <div id="transcript" style="margin-top: 20px; padding: 10px; border: 1px solid #ccc;"></div>
+
+     <!-- Load our WebRTC client code -->
+     <script src="/static/client.js"></script>
+ </body>
+ </html>
utils/__init__.py ADDED
File without changes
utils/device.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import numpy as np
+
+ def get_device(force_cpu=False):
+     if force_cpu:
+         return "cpu"
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         torch.mps.empty_cache()
+         return "mps"
+     else:
+         return "cpu"
+
+ def get_torch_and_np_dtypes(device, use_bfloat16=False):
+     if device == "cuda":
+         torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
+         np_dtype = np.float16
+     elif device == "mps":
+         torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
+         np_dtype = np.float16
+     else:
+         torch_dtype = torch.float32
+         np_dtype = np.float32
+     return torch_dtype, np_dtype
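
A usage sketch for these helpers; this mirrors the startup code in app.py and introduces no new API:

from utils.device import get_device, get_torch_and_np_dtypes

device = get_device(force_cpu=False)  # "cuda", "mps", or "cpu"
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
# CPU yields torch.float32 / np.float32; CUDA and MPS yield float16
# (or torch.bfloat16 when use_bfloat16=True)
print(device, torch_dtype, np_dtype)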
utils/logger_config.py ADDED
@@ -0,0 +1,86 @@
+ import logging
+ import sys
+ import os
+
+ LOGS_DIR = "logs"
+
+ class ColorFormatter(logging.Formatter):
+     """Custom formatter that adds colors to log levels"""
+
+     grey = "\x1b[38;20m"
+     yellow = "\x1b[33;20m"
+     red = "\x1b[31;20m"
+     bold_red = "\x1b[31;1m"
+     blue = "\x1b[34;20m"
+     green = "\x1b[32;20m"
+     reset = "\x1b[0m"
+
+     format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+     FORMATS = {
+         logging.DEBUG: blue + format_str + reset,
+         logging.INFO: green + format_str + reset,
+         logging.WARNING: yellow + format_str + reset,
+         logging.ERROR: red + format_str + reset,
+         logging.CRITICAL: bold_red + format_str + reset
+     }
+
+     def format(self, record):
+         log_fmt = self.FORMATS.get(record.levelno)
+         formatter = logging.Formatter(log_fmt, datefmt='%Y-%m-%d %H:%M:%S')
+         return formatter.format(record)
+
+ def configure_logfire():
+     import logfire
+     # First run `logfire auth`
+     # -> Your Logfire credentials are stored in <path>/.logfire/default.toml
+
+     def scrubbing_callback(m: logfire.ScrubMatch):
+         if m.pattern_match.group(0) == 'Credit Card':
+             return m.value
+
+     logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
+
+ def setup_logging(level=None, with_logfire=False):
+     """Configure logging for the entire application"""
+     if with_logfire:
+         configure_logfire()
+
+     # Get level from environment variable or use default
+     if level is None:
+         level_name = os.getenv('LOG_LEVEL', 'INFO')
+         level = getattr(logging, level_name.upper(), logging.INFO)
+
+     # Configure stream handler (console output) with color formatter
+     stream_handler = logging.StreamHandler(sys.stdout)
+     stream_handler.setFormatter(ColorFormatter())
+
+     # Configure root logger
+     root_logger = logging.getLogger()
+     root_logger.setLevel(level)
+
+     # Remove existing handlers to avoid duplicate log lines
+     root_logger.handlers = []
+     root_logger.addHandler(stream_handler)
+
+     # Prevent duplicate logging
+     root_logger.propagate = False
+
+     # Configure file handler
+     os.makedirs(LOGS_DIR, exist_ok=True)
+     file_handler = logging.FileHandler(os.path.join(LOGS_DIR, 'app.log'))
+     file_handler.setFormatter(logging.Formatter(
+         '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S'
+     ))
+     root_logger.addHandler(file_handler)
+
+     # Get comma-separated list of loggers to suppress from env
+     suppress_loggers = os.getenv('SUPPRESS_LOGGERS', '').strip()
+     if suppress_loggers:
+         for logger_name in suppress_loggers.split(','):
+             logger_name = logger_name.strip()
+             if logger_name:
+                 logging.getLogger(logger_name).setLevel(logging.WARNING)
+
+     logging.info(f"Logging configured with level: {logging.getLevelName(level)}")
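
A usage sketch, matching how app.py calls this at startup; LOG_LEVEL and SUPPRESS_LOGGERS are read from the environment when not passed explicitly:

import logging
from utils.logger_config import setup_logging

setup_logging()  # level falls back to the LOG_LEVEL env var, default INFO
logging.getLogger(__name__).info("console gets colors; logs/app.log gets plain text")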
utils/turn_server.py ADDED
@@ -0,0 +1,119 @@
+ import os
+ from typing import Literal, Optional, Dict, Any
+ import requests
+
+ from fastrtc import get_hf_turn_credentials, get_twilio_turn_credentials
+
+
+ def get_rtc_credentials(
+     provider: Literal["hf", "twilio", "cloudflare"] = "hf",
+     **kwargs
+ ) -> Dict[str, Any]:
+     """
+     Get RTC configuration for different TURN server providers.
+
+     Args:
+         provider: The TURN server provider to use ('hf', 'twilio', or 'cloudflare')
+         **kwargs: Additional arguments passed to the specific provider's function
+
+     Returns:
+         Dictionary containing the RTC configuration
+     """
+     try:
+         if provider == "hf":
+             return get_hf_credentials(**kwargs)
+         elif provider == "twilio":
+             return get_twilio_credentials(**kwargs)
+         elif provider == "cloudflare":
+             return get_cloudflare_credentials(**kwargs)
+         else:
+             raise ValueError(f"Unknown provider: {provider}")
+     except Exception as e:
+         raise Exception(f"Failed to get RTC credentials ({provider}): {str(e)}")
+
+
+ def get_hf_credentials(token: Optional[str] = None) -> Dict[str, Any]:
+     """
+     Get credentials for Hugging Face's community TURN server.
+
+     Required setup:
+     1. Create a Hugging Face account at huggingface.co
+     2. Visit: https://huggingface.co/spaces/fastrtc/turn-server-login
+     3. Set HF_TOKEN environment variable or pass token directly
+     """
+     token = token or os.environ.get("HF_TOKEN")
+     if not token:
+         raise ValueError("HF_TOKEN environment variable not set")
+
+     try:
+         return get_hf_turn_credentials(token=token)
+     except Exception as e:
+         raise Exception(f"Failed to get HF TURN credentials: {str(e)}")
+
+
+ def get_twilio_credentials(
+     account_sid: Optional[str] = None,
+     auth_token: Optional[str] = None
+ ) -> Dict[str, Any]:
+     """
+     Get credentials for Twilio's TURN server.
+
+     Required setup:
+     1. Create a free Twilio account at: https://login.twilio.com/u/signup
+     2. Get your Account SID and Auth Token from the Twilio Console
+     3. Set environment variables:
+        - TWILIO_ACCOUNT_SID (or pass directly)
+        - TWILIO_AUTH_TOKEN (or pass directly)
+     """
+     account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID")
+     auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN")
+
+     if not account_sid or not auth_token:
+         raise ValueError("Twilio credentials not found. Set TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN env vars")
+
+     try:
+         return get_twilio_turn_credentials(account_sid=account_sid, auth_token=auth_token)
+     except Exception as e:
+         raise Exception(f"Failed to get Twilio TURN credentials: {str(e)}")
+
+
+ def get_cloudflare_credentials(
+     key_id: Optional[str] = None,
+     api_token: Optional[str] = None,
+     ttl: int = 86400
+ ) -> Dict[str, Any]:
+     """
+     Get credentials for Cloudflare's TURN server.
+
+     Required setup:
+     1. Create a free Cloudflare account
+     2. Go to Cloudflare dashboard -> Calls section
+     3. Create a TURN App and get the Turn Token ID and API Token
+     4. Set environment variables:
+        - TURN_KEY_ID
+        - TURN_KEY_API_TOKEN
+
+     Args:
+         key_id: Cloudflare Turn Token ID (optional, will use env var if not provided)
+         api_token: Cloudflare API Token (optional, will use env var if not provided)
+         ttl: Time-to-live for credentials in seconds (default: 24 hours)
+     """
+     key_id = key_id or os.environ.get("TURN_KEY_ID")
+     api_token = api_token or os.environ.get("TURN_KEY_API_TOKEN")
+
+     if not key_id or not api_token:
+         raise ValueError("Cloudflare credentials not found. Set TURN_KEY_ID and TURN_KEY_API_TOKEN env vars")
+
+     response = requests.post(
+         f"https://rtc.live.cloudflare.com/v1/turn/keys/{key_id}/credentials/generate",
+         headers={
+             "Authorization": f"Bearer {api_token}",
+             "Content-Type": "application/json",
+         },
+         json={"ttl": ttl},
+     )
+
+     if response.ok:
+         return {"iceServers": [response.json()["iceServers"]]}
+     else:
+         raise Exception(
+             f"Failed to get Cloudflare TURN credentials: {response.status_code} {response.text}"
+         )
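
A usage sketch, assuming HF_TOKEN is set in the environment as the docstring above requires; app.py only calls this when APP_MODE=deployed:

from utils.turn_server import get_rtc_credentials

rtc_config = get_rtc_credentials(provider="hf")
# The returned dict is what app.py passes to Stream(rtc_configuration=...)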