Sofia Casadei committed
Commit 5ef2360 · Parent(s): aec5df4

first version
Files changed:
- .gitignore +1 -0
- Dockerfile +50 -0
- README.md +1 -0
- app.py +148 -0
- requirements.txt +9 -0
- static/client.js +127 -0
- static/index.html +53 -0
- utils/__init__.py +0 -0
- utils/device.py +25 -0
- utils/logger_config.py +86 -0
- utils/turn_server.py +119 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__/
Dockerfile
ADDED
@@ -0,0 +1,50 @@
+# Stage 1: Get uv installer
+FROM ghcr.io/astral-sh/uv:0.2.12 as uv
+
+# Stage 2: Main application image
+FROM python:3.10.12-slim-bookworm
+
+# Copy uv from first stage
+COPY --from=uv /uv /uv
+
+# Create virtual environment with uv
+RUN --mount=type=cache,target=/root/.cache/uv \
+    /uv venv /opt/venv
+
+# Set environment variables
+ENV VIRTUAL_ENV=/opt/venv \
+    PATH="/opt/venv/bin:$PATH"
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    portaudio19-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create user and set permissions (required for HF Spaces)
+RUN useradd -m -u 1000 user && \
+    chown -R user /opt/venv
+
+# Switch to user context
+USER user
+WORKDIR /app
+
+# Set home to user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH \
+    HF_HOME=/home/user/.cache/huggingface
+
+# Copy requirements first for caching
+COPY --chown=user requirements.txt .
+
+# Install Python packages with uv caching
+RUN --mount=type=cache,target=/home/user/.cache/uv \
+    /uv pip install -r requirements.txt
+
+# Copy application code
+COPY --chown=user . .
+
+# Expose FastRTC port (matches HF Spaces default)
+EXPOSE 7860
+
+# Start the application using uvicorn (app.py defines the FastAPI `app`)
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -4,6 +4,7 @@ emoji: 🐢
 colorFrom: indigo
 colorTo: gray
 sdk: docker
+app_port: 7860
 pinned: false
 ---
 
app.py
ADDED
@@ -0,0 +1,148 @@
+import os
+import logging
+
+import gradio as gr
+import numpy as np
+from dotenv import load_dotenv
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse, HTMLResponse
+from fastrtc import (
+    AdditionalOutputs,
+    ReplyOnPause,
+    Stream,
+    AlgoOptions,
+    SileroVadOptions,
+    audio_to_bytes,
+)
+from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+    pipeline,
+)
+from transformers.utils import is_flash_attn_2_available
+
+from utils.logger_config import setup_logging
+from utils.device import get_device, get_torch_and_np_dtypes
+from utils.turn_server import get_rtc_credentials
+
+
+load_dotenv()
+setup_logging(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+device = get_device(force_cpu=False)
+torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
+logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
+
+
+attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
+logger.info(f"Using attention: {attention}")
+
+
+model_id = "openai/whisper-large-v3-turbo"
+logger.info(f"Loading Whisper model: {model_id}")
+model = AutoModelForSpeechSeq2Seq.from_pretrained(
+    model_id,
+    torch_dtype=torch_dtype,
+    low_cpu_mem_usage=True,
+    use_safetensors=True,
+    attn_implementation=attention
+)
+model.to(device)
+
+processor = AutoProcessor.from_pretrained(model_id)
+
+transcribe_pipeline = pipeline(
+    task="automatic-speech-recognition",
+    model=model,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
+    torch_dtype=torch_dtype,
+    device=device,
+)
+
+# Warm up the model with empty audio
+logger.info("Warming up Whisper model with dummy input")
+warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence
+transcribe_pipeline(warmup_audio)
+logger.info("Model warmup complete")
+
+
+async def transcribe(audio: tuple[int, np.ndarray]):
+    sample_rate, audio_array = audio
+    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
+
+    outputs = transcribe_pipeline(
+        audio_to_bytes(audio),
+        chunk_length_s=3,
+        batch_size=1,
+        generate_kwargs={
+            'task': 'transcribe',
+            'language': 'english',
+        },
+        # return_timestamps="word"
+    )
+    yield AdditionalOutputs(outputs["text"].strip())
+
+
+logger.info("Initializing FastRTC stream")
+stream = Stream(
+    handler=ReplyOnPause(
+        transcribe,
+        algo_options=AlgoOptions(
+            # Duration in seconds of audio chunks (default 0.6)
+            audio_chunk_duration=0.6,
+            # If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
+            started_talking_threshold=0.2,
+            # If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking (default 0.1)
+            speech_threshold=0.1,
+        ),
+        model_options=SileroVadOptions(
+            # Threshold for what is considered speech (default 0.5)
+            threshold=0.5,
+            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
+            min_speech_duration_ms=250,
+            # Max duration of speech chunks; longer chunks are split (default float('inf'))
+            max_speech_duration_s=30,
+            # Wait for min_silence_duration_ms at the end of each speech chunk before separating it (default 2000)
+            min_silence_duration_ms=2000,
+            # Chunk size for the VAD model; can be 512, 1024, or 1536 at a 16 kHz sample rate (default 1024)
+            window_size_samples=1024,
+            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
+            speech_pad_ms=400,
+        ),
+    ),
+    # send-receive: bidirectional streaming (default)
+    # send: client to server only
+    # receive: server to client only
+    modality="audio",
+    mode="send",
+    additional_outputs=[
+        gr.Textbox(label="Transcript"),
+    ],
+    additional_outputs_handler=lambda current, new: current + " " + new,
+    rtc_configuration=get_rtc_credentials(provider="hf") if os.getenv("APP_MODE") == "deployed" else None
+)
+
+app = FastAPI()
+stream.mount(app)
+
+@app.get("/transcript")
+def _(webrtc_id: str):
+    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
+    async def output_stream():
+        try:
+            async for output in stream.output_stream(webrtc_id):
+                transcript = output.args[0]
+                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
+                yield f"event: output\ndata: {transcript}\n\n"
+        except Exception as e:
+            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
+            raise
+
+    return StreamingResponse(output_stream(), media_type="text/event-stream")
+
+@app.get("/")
+async def root():
+    return HTMLResponse(content=open("static/index.html").read())
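Note: index.html (below) loads /static/client.js, but app.py as committed appears to register no route for the static/ directory, so that request would 404. A minimal sketch of the missing wiring, assuming FastAPI's StaticFiles and placement after `app = FastAPI()`:

```python
# Sketch (not part of this commit): serve the static/ directory so that
# /static/client.js, referenced by index.html, resolves.
from fastapi.staticfiles import StaticFiles

app.mount("/static", StaticFiles(directory="static"), name="static")
```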
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+accelerate==1.4.0
+fastrtc==0.0.10
+fastrtc[vad]
+python-dotenv==1.0.1
+transformers==4.49.0
+torch==2.6.0
+torchaudio==2.6.0
+fastapi
+uvicorn[standard]
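Note: fastrtc is pinned on one line while the [vad] extra is declared unpinned on the next; the two specifiers resolve together, but they could be collapsed into the single line `fastrtc[vad]==0.0.10`.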
static/client.js
ADDED
@@ -0,0 +1,127 @@
+// Global variables
+let peerConnection = null;
+let dataChannel = null;
+let webrtcId = null;
+
+// Helper function to generate unique ID
+function generateUniqueId() {
+    return Math.random().toString(36).substring(7);
+}
+
+// Update UI status
+function updateStatus(connected) {
+    const statusDiv = document.getElementById('status');
+    const connectBtn = document.getElementById('connectBtn');
+    const disconnectBtn = document.getElementById('disconnectBtn');
+
+    statusDiv.textContent = connected ? 'Connected' : 'Disconnected';
+    statusDiv.className = connected ? 'connected' : 'disconnected';
+    connectBtn.disabled = connected;
+    disconnectBtn.disabled = !connected;
+}
+
+// Setup WebRTC connection
+async function setupWebRTC() {
+    try {
+        // Create peer connection
+        peerConnection = new RTCPeerConnection();
+        webrtcId = generateUniqueId();
+
+        // Get audio stream from microphone
+        const stream = await navigator.mediaDevices.getUserMedia({
+            audio: true
+        });
+
+        // Add audio stream to peer connection
+        stream.getTracks().forEach(track => {
+            peerConnection.addTrack(track, stream);
+        });
+
+        // Create data channel
+        dataChannel = peerConnection.createDataChannel("text");
+
+        // Handle data channel messages
+        dataChannel.onmessage = (event) => {
+            const message = JSON.parse(event.data);
+            console.log("Received message:", message);
+
+            // Handle different message types
+            switch (message.type) {
+                case 'log':
+                    console.log("Server log:", message.data);
+                    break;
+                case 'error':
+                    console.error("Server error:", message.data);
+                    break;
+                case 'warning':
+                    console.warn("Server warning:", message.data);
+                    break;
+            }
+        };
+
+        // Create and send offer
+        const offer = await peerConnection.createOffer();
+        await peerConnection.setLocalDescription(offer);
+
+        // Send offer to server
+        const response = await fetch('/webrtc/offer', {
+            method: 'POST',
+            headers: { 'Content-Type': 'application/json' },
+            body: JSON.stringify({
+                sdp: offer.sdp,
+                type: offer.type,
+                webrtc_id: webrtcId
+            })
+        });
+
+        if (!response.ok) {
+            throw new Error(`HTTP error! status: ${response.status}`);
+        }
+
+        // Handle server response
+        const serverResponse = await response.json();
+
+        // Check for error response
+        if (serverResponse.status === 'failed') {
+            throw new Error(serverResponse.meta.error);
+        }
+
+        // Set remote description
+        await peerConnection.setRemoteDescription(serverResponse);
+
+        // Update UI
+        updateStatus(true);
+
+        // Listen for transcripts via server-sent events; the server emits
+        // named "output" events, so use addEventListener rather than onmessage
+        const eventSource = new EventSource(`/transcript?webrtc_id=${webrtcId}`);
+
+        eventSource.addEventListener('output', (event) => {
+            const transcriptDiv = document.getElementById('transcript');
+            transcriptDiv.innerHTML += `<p>${event.data}</p>`;
+        });
+
+    } catch (error) {
+        console.error("Error setting up WebRTC:", error);
+        updateStatus(false);
+    }
+}
+
+// Cleanup function
+function disconnect() {
+    if (peerConnection) {
+        peerConnection.close();
+        peerConnection = null;
+    }
+    if (dataChannel) {
+        dataChannel.close();
+        dataChannel = null;
+    }
+    webrtcId = null;
+    updateStatus(false);
+}
+
+// Add event listeners when page loads
+document.addEventListener('DOMContentLoaded', () => {
+    document.getElementById('connectBtn').addEventListener('click', setupWebRTC);
+    document.getElementById('disconnectBtn').addEventListener('click', disconnect);
+});
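The transcript wiring above can also be exercised outside the browser. A minimal Python sketch of an SSE consumer for /transcript, assuming the server runs locally on port 7860, `webrtc_id` matches a live session, and requests as an extra dependency:

```python
# Sketch: follow the /transcript SSE feed for an active WebRTC session.
import requests

def follow_transcript(webrtc_id: str, base_url: str = "http://localhost:7860") -> None:
    with requests.get(
        f"{base_url}/transcript",
        params={"webrtc_id": webrtc_id},
        stream=True,
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=True):
            # Each SSE event arrives as "event: output" followed by "data: <text>"
            if line.startswith("data:"):
                print(line[len("data:"):].strip())

if __name__ == "__main__":
    follow_transcript("replace-with-live-webrtc-id")  # hypothetical ID
```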
static/index.html
ADDED
@@ -0,0 +1,53 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>FastRTC Audio Client</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            max-width: 800px;
+            margin: 0 auto;
+            padding: 20px;
+        }
+        .controls {
+            margin: 20px 0;
+        }
+        button {
+            padding: 10px 20px;
+            margin: 5px;
+        }
+        #status {
+            margin: 10px 0;
+            padding: 10px;
+            border-radius: 4px;
+        }
+        .connected {
+            background-color: #d4edda;
+            color: #155724;
+        }
+        .disconnected {
+            background-color: #f8d7da;
+            color: #721c24;
+        }
+    </style>
+</head>
+<body>
+    <h1>FastRTC Audio Client</h1>
+    <div id="status" class="disconnected">Disconnected</div>
+
+    <div class="controls">
+        <button id="connectBtn">Connect</button>
+        <button id="disconnectBtn" disabled>Disconnect</button>
+    </div>
+
+    <!-- Audio element for playback -->
+    <audio id="audioOutput" autoplay></audio>
+
+    <div id="transcript" style="margin-top: 20px; padding: 10px; border: 1px solid #ccc;"></div>
+
+    <!-- Load our WebRTC client code -->
+    <script src="/static/client.js"></script>
+</body>
+</html>
utils/__init__.py
ADDED
(empty file)
utils/device.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+import numpy as np
+
+def get_device(force_cpu=False):
+    if force_cpu:
+        return "cpu"
+    if torch.cuda.is_available():
+        return "cuda"
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+        return "mps"
+    else:
+        return "cpu"
+
+def get_torch_and_np_dtypes(device, use_bfloat16=False):
+    if device == "cuda":
+        torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
+        np_dtype = np.float16
+    elif device == "mps":
+        torch_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
+        np_dtype = np.float16
+    else:
+        torch_dtype = torch.float32
+        np_dtype = np.float32
+    return torch_dtype, np_dtype
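A quick usage sketch for these helpers, mirroring how app.py calls them at startup (the printed output is illustrative):

```python
# Sketch: resolve the device and matching dtypes, as app.py does.
from utils.device import get_device, get_torch_and_np_dtypes

device = get_device(force_cpu=False)  # "cuda", "mps", or "cpu"
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
print(device, torch_dtype, np_dtype)  # e.g. cuda torch.float16 <class 'numpy.float16'>
```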
utils/logger_config.py
ADDED
@@ -0,0 +1,86 @@
+import logging
+import sys
+import os
+
+LOGS_DIR = "logs"
+
+class ColorFormatter(logging.Formatter):
+    """Custom formatter that adds colors to log levels"""
+
+    grey = "\x1b[38;20m"
+    yellow = "\x1b[33;20m"
+    red = "\x1b[31;20m"
+    bold_red = "\x1b[31;1m"
+    blue = "\x1b[34;20m"
+    green = "\x1b[32;20m"
+    reset = "\x1b[0m"
+
+    format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+    FORMATS = {
+        logging.DEBUG: blue + format_str + reset,
+        logging.INFO: green + format_str + reset,
+        logging.WARNING: yellow + format_str + reset,
+        logging.ERROR: red + format_str + reset,
+        logging.CRITICAL: bold_red + format_str + reset
+    }
+
+    def format(self, record):
+        log_fmt = self.FORMATS.get(record.levelno)
+        formatter = logging.Formatter(log_fmt, datefmt='%Y-%m-%d %H:%M:%S')
+        return formatter.format(record)
+
+def configure_logfire():
+    import logfire
+    # First run `logfire auth`
+    # -> Your Logfire credentials are stored in <path>/.logfire/default.toml
+
+    def scrubbing_callback(m: logfire.ScrubMatch):
+        if m.pattern_match.group(0) == 'Credit Card':
+            return m.value
+
+    logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
+
+def setup_logging(level=None, with_logfire=False):
+    """Configure logging for the entire application"""
+    if with_logfire:
+        configure_logfire()
+
+    # Get level from environment variable or use default
+    if level is None:
+        level_name = os.getenv('LOG_LEVEL', 'INFO')
+        level = getattr(logging, level_name.upper(), logging.INFO)
+
+    # Configure stream handler (console output) with color formatter
+    stream_handler = logging.StreamHandler(sys.stdout)
+    stream_handler.setFormatter(ColorFormatter())
+
+    # Configure root logger
+    root_logger = logging.getLogger()
+    root_logger.setLevel(level)
+
+    # Remove existing handlers
+    root_logger.handlers = []
+    root_logger.addHandler(stream_handler)
+
+    # Prevent duplicate logging (a no-op on the root logger, but harmless)
+    root_logger.propagate = False
+
+    # Also write logs to a file
+    os.makedirs(LOGS_DIR, exist_ok=True)
+    file_handler = logging.FileHandler(os.path.join(LOGS_DIR, 'app.log'))
+    file_handler.setFormatter(logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    ))
+    root_logger.addHandler(file_handler)
+
+    # Get comma-separated list of loggers to suppress from env
+    suppress_loggers = os.getenv('SUPPRESS_LOGGERS', '').strip()
+    if suppress_loggers:
+        for logger_name in suppress_loggers.split(','):
+            logger_name = logger_name.strip()
+            if logger_name:
+                logging.getLogger(logger_name).setLevel(logging.WARNING)
+
+    logging.info(f"Logging configured with level: {logging.getLevelName(level)}")
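Usage sketch: the level can be passed explicitly or read from LOG_LEVEL, and SUPPRESS_LOGGERS quiets noisy third-party loggers (the logger names below are illustrative, not from this repo):

```python
# Sketch: configure logging once at startup, before other modules log.
import logging
import os

from utils.logger_config import setup_logging

os.environ["SUPPRESS_LOGGERS"] = "urllib3,asyncio"  # illustrative names
setup_logging()  # level taken from LOG_LEVEL (default INFO)
# or: setup_logging(level=logging.DEBUG, with_logfire=False)

logging.getLogger(__name__).info("logging is wired up")
```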
utils/turn_server.py
ADDED
@@ -0,0 +1,119 @@
+import os
+from typing import Literal, Optional, Dict, Any
+import requests
+
+from fastrtc import get_hf_turn_credentials, get_twilio_turn_credentials
+
+
+def get_rtc_credentials(
+    provider: Literal["hf", "twilio", "cloudflare"] = "hf",
+    **kwargs
+) -> Dict[str, Any]:
+    """
+    Get RTC configuration for different TURN server providers.
+
+    Args:
+        provider: The TURN server provider to use ('hf', 'twilio', or 'cloudflare')
+        **kwargs: Additional arguments passed to the specific provider's function
+
+    Returns:
+        Dictionary containing the RTC configuration
+    """
+    try:
+        if provider == "hf":
+            return get_hf_credentials(**kwargs)
+        elif provider == "twilio":
+            return get_twilio_credentials(**kwargs)
+        elif provider == "cloudflare":
+            return get_cloudflare_credentials(**kwargs)
+    except Exception as e:
+        raise Exception(f"Failed to get RTC credentials ({provider}): {str(e)}")
+
+
+def get_hf_credentials(token: Optional[str] = None) -> Dict[str, Any]:
+    """
+    Get credentials for Hugging Face's community TURN server.
+
+    Required setup:
+    1. Create a Hugging Face account at huggingface.co
+    2. Visit: https://huggingface.co/spaces/fastrtc/turn-server-login
+    3. Set HF_TOKEN environment variable or pass token directly
+    """
+    token = token or os.environ.get("HF_TOKEN")
+    if not token:
+        raise ValueError("HF_TOKEN environment variable not set")
+
+    try:
+        return get_hf_turn_credentials(token=token)
+    except Exception as e:
+        raise Exception(f"Failed to get HF TURN credentials: {str(e)}")
+
+
+def get_twilio_credentials(
+    account_sid: Optional[str] = None,
+    auth_token: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Get credentials for Twilio's TURN server.
+
+    Required setup:
+    1. Create a free Twilio account at: https://login.twilio.com/u/signup
+    2. Get your Account SID and Auth Token from the Twilio Console
+    3. Set environment variables:
+       - TWILIO_ACCOUNT_SID (or pass directly)
+       - TWILIO_AUTH_TOKEN (or pass directly)
+    """
+    account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID")
+    auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN")
+
+    if not account_sid or not auth_token:
+        raise ValueError("Twilio credentials not found. Set TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN env vars")
+
+    try:
+        return get_twilio_turn_credentials(account_sid=account_sid, auth_token=auth_token)
+    except Exception as e:
+        raise Exception(f"Failed to get Twilio TURN credentials: {str(e)}")
+
+
+def get_cloudflare_credentials(
+    key_id: Optional[str] = None,
+    api_token: Optional[str] = None,
+    ttl: int = 86400
+) -> Dict[str, Any]:
+    """
+    Get credentials for Cloudflare's TURN server.
+
+    Required setup:
+    1. Create a free Cloudflare account
+    2. Go to Cloudflare dashboard -> Calls section
+    3. Create a TURN App and get the Turn Token ID and API Token
+    4. Set environment variables:
+       - TURN_KEY_ID
+       - TURN_KEY_API_TOKEN
+
+    Args:
+        key_id: Cloudflare Turn Token ID (optional, will use env var if not provided)
+        api_token: Cloudflare API Token (optional, will use env var if not provided)
+        ttl: Time-to-live for credentials in seconds (default: 24 hours)
+    """
+    key_id = key_id or os.environ.get("TURN_KEY_ID")
+    api_token = api_token or os.environ.get("TURN_KEY_API_TOKEN")
+
+    if not key_id or not api_token:
+        raise ValueError("Cloudflare credentials not found. Set TURN_KEY_ID and TURN_KEY_API_TOKEN env vars")
+
+    response = requests.post(
+        f"https://rtc.live.cloudflare.com/v1/turn/keys/{key_id}/credentials/generate",
+        headers={
+            "Authorization": f"Bearer {api_token}",
+            "Content-Type": "application/json",
+        },
+        json={"ttl": ttl},
+    )
+
+    if response.ok:
+        return {"iceServers": [response.json()["iceServers"]]}
+    else:
+        raise Exception(
+            f"Failed to get Cloudflare TURN credentials: {response.status_code} {response.text}"
+        )
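Usage sketch matching the call site in app.py: credentials are fetched only in deployed mode (the HF provider needs HF_TOKEN in the environment), and the resulting dict is handed to Stream(rtc_configuration=...):

```python
# Sketch: TURN credentials for deployment, direct connection for local dev.
import os

from utils.turn_server import get_rtc_credentials

rtc_configuration = (
    get_rtc_credentials(provider="hf")  # requires HF_TOKEN
    if os.getenv("APP_MODE") == "deployed"
    else None
)
```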