Spaces:
Sleeping
Sleeping
main.py
CHANGED
@@ -13,6 +13,7 @@ import logging
|
|
13 |
import io
|
14 |
from pydub import AudioSegment
|
15 |
from typing import List
|
|
|
16 |
|
17 |
logging.basicConfig(level=logging.INFO)
|
18 |
logger = logging.getLogger(__name__)
|
@@ -46,17 +47,28 @@ model = joblib.load('models/xgb_test.pkl')
|
|
46 |
class ConnectionManager:
|
47 |
def __init__(self):
|
48 |
self.active_connections: List[WebSocket] = []
|
|
|
49 |
|
50 |
async def connect(self, websocket: WebSocket):
|
51 |
await websocket.accept()
|
52 |
self.active_connections.append(websocket)
|
|
|
53 |
|
54 |
def disconnect(self, websocket: WebSocket):
|
55 |
self.active_connections.remove(websocket)
|
|
|
56 |
|
57 |
-
async def send_message(self, message: str):
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
manager = ConnectionManager()
|
62 |
|
@@ -78,23 +90,14 @@ def extract_features(audio):
|
|
78 |
combined_features = np.hstack([mfccs, chroma, contrast, centroid])
|
79 |
return combined_features
|
80 |
|
81 |
-
async def process_audio_data(
|
|
|
82 |
try:
|
83 |
-
|
84 |
-
logger.info(f"Raw audio length: {len(audio_data)}")
|
85 |
-
# Attempt to convert audio data from webm/ogg to wav format using pydub
|
86 |
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="webm")
|
87 |
-
except Exception as e:
|
88 |
-
logger.error(f"Failed to convert audio data using pydub: {e}")
|
89 |
-
return
|
90 |
-
|
91 |
-
try:
|
92 |
-
# Export the audio segment to wav format
|
93 |
wav_io = io.BytesIO()
|
94 |
audio_segment.export(wav_io, format="wav")
|
95 |
wav_io.seek(0)
|
96 |
-
|
97 |
-
# Read the audio data
|
98 |
audio, sr = sf.read(wav_io, dtype='float32')
|
99 |
except Exception as e:
|
100 |
logger.error(f"Failed to read audio data: {e}")
|
@@ -103,20 +106,14 @@ async def process_audio_data(audio_data):
|
|
103 |
if audio.ndim > 1: # If audio has more than one channel, average them
|
104 |
audio = np.mean(audio, axis=1)
|
105 |
|
106 |
-
logger.info(f"The len of audio: {len(audio)}")
|
107 |
-
logger.info("Extracting features")
|
108 |
features = extract_features(audio)
|
109 |
features = features.reshape(1, -1)
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
logger.error(f"Model prediction failed: {e}")
|
117 |
-
return
|
118 |
-
|
119 |
-
await manager.send_message(result)
|
120 |
|
121 |
@app.post("/start_detection")
|
122 |
async def start_detection():
|
@@ -138,7 +135,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
138 |
try:
|
139 |
while True:
|
140 |
data = await websocket.receive_bytes()
|
141 |
-
|
|
|
142 |
except WebSocketDisconnect:
|
143 |
manager.disconnect(websocket)
|
144 |
|
|
|
13 |
import io
|
14 |
from pydub import AudioSegment
|
15 |
from typing import List
|
16 |
+
import asyncio
|
17 |
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logger = logging.getLogger(__name__)
|
|
|
47 |
class ConnectionManager:
    """Track accepted WebSocket connections and one byte buffer per connection.

    ``active_connections`` lists every websocket registered via ``connect``;
    ``audio_buffers`` maps each of those websockets to the bytes accumulated
    for it so far.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # websocket -> bytes received and not yet cleared for that connection.
        self.audio_buffers = {}

    async def connect(self, websocket: WebSocket):
        """Accept the websocket, register it and give it an empty buffer."""
        await websocket.accept()
        self.active_connections.append(websocket)
        self.audio_buffers[websocket] = b''  # Initialize buffer for each connection

    def disconnect(self, websocket: WebSocket):
        """Unregister the websocket and drop its buffer.

        Calling this for a websocket that is unknown or already removed is a
        no-op, so a cleanup path that runs twice cannot raise ``ValueError`` /
        ``KeyError`` or leave a stale buffer behind.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        # pop() with a default cleans up the buffer even if the list entry
        # was already gone.
        self.audio_buffers.pop(websocket, None)

    async def send_message(self, websocket: WebSocket, message: str):
        """Send ``message`` as a text frame to the given websocket only."""
        await websocket.send_text(message)

    def add_to_buffer(self, websocket: WebSocket, data: bytes):
        """Append ``data`` to the buffer of a registered websocket.

        Raises ``KeyError`` if the websocket has no buffer (not connected).
        """
        self.audio_buffers[websocket] += data  # Accumulate data in the buffer

    def get_buffer(self, websocket: WebSocket) -> bytes:
        """Return the bytes accumulated so far for the websocket.

        Raises ``KeyError`` if the websocket has no buffer (not connected).
        """
        return self.audio_buffers[websocket]

    def clear_buffer(self, websocket: WebSocket):
        """Reset the websocket's buffer to empty bytes."""
        self.audio_buffers[websocket] = b''  # Clear the buffer
|
72 |
|
73 |
manager = ConnectionManager()
|
74 |
|
|
|
90 |
combined_features = np.hstack([mfccs, chroma, contrast, centroid])
|
91 |
return combined_features
|
92 |
|
93 |
+
async def process_audio_data(websocket: WebSocket):
|
94 |
+
audio_data = manager.get_buffer(websocket)
|
95 |
try:
|
96 |
+
# Convert audio data from webm/ogg to wav format using pydub
|
|
|
|
|
97 |
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="webm")
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
wav_io = io.BytesIO()
|
99 |
audio_segment.export(wav_io, format="wav")
|
100 |
wav_io.seek(0)
|
|
|
|
|
101 |
audio, sr = sf.read(wav_io, dtype='float32')
|
102 |
except Exception as e:
|
103 |
logger.error(f"Failed to read audio data: {e}")
|
|
|
106 |
if audio.ndim > 1: # If audio has more than one channel, average them
|
107 |
audio = np.mean(audio, axis=1)
|
108 |
|
|
|
|
|
109 |
features = extract_features(audio)
|
110 |
features = features.reshape(1, -1)
|
111 |
+
prediction = model.predict(features)
|
112 |
+
is_fake = prediction[0]
|
113 |
+
result = 'fake' if is_fake else 'real'
|
114 |
+
|
115 |
+
await manager.send_message(websocket, result)
|
116 |
+
manager.clear_buffer(websocket)
|
|
|
|
|
|
|
|
|
117 |
|
118 |
@app.post("/start_detection")
|
119 |
async def start_detection():
|
|
|
135 |
try:
|
136 |
while True:
|
137 |
data = await websocket.receive_bytes()
|
138 |
+
manager.add_to_buffer(websocket, data)
|
139 |
+
await process_audio_data(websocket)
|
140 |
except WebSocketDisconnect:
|
141 |
manager.disconnect(websocket)
|
142 |
|