kowalsky committed on
Commit
74a687d
·
1 Parent(s): 532ae6e
Files changed (1) hide show
  1. main.py +26 -28
main.py CHANGED
@@ -13,6 +13,7 @@ import logging
13
  import io
14
  from pydub import AudioSegment
15
  from typing import List
 
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
@@ -46,17 +47,28 @@ model = joblib.load('models/xgb_test.pkl')
46
  class ConnectionManager:
47
  def __init__(self):
48
  self.active_connections: List[WebSocket] = []
 
49
 
50
  async def connect(self, websocket: WebSocket):
51
  await websocket.accept()
52
  self.active_connections.append(websocket)
 
53
 
54
  def disconnect(self, websocket: WebSocket):
55
  self.active_connections.remove(websocket)
 
56
 
57
- async def send_message(self, message: str):
58
- for connection in self.active_connections:
59
- await connection.send_text(message)
 
 
 
 
 
 
 
 
60
 
61
  manager = ConnectionManager()
62
 
@@ -78,23 +90,14 @@ def extract_features(audio):
78
  combined_features = np.hstack([mfccs, chroma, contrast, centroid])
79
  return combined_features
80
 
81
- async def process_audio_data(audio_data):
 
82
  try:
83
- logger.info(f"Audio data type: {type(audio_data)}")
84
- logger.info(f"Raw audio length: {len(audio_data)}")
85
- # Attempt to convert audio data from webm/ogg to wav format using pydub
86
  audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="webm")
87
- except Exception as e:
88
- logger.error(f"Failed to convert audio data using pydub: {e}")
89
- return
90
-
91
- try:
92
- # Export the audio segment to wav format
93
  wav_io = io.BytesIO()
94
  audio_segment.export(wav_io, format="wav")
95
  wav_io.seek(0)
96
-
97
- # Read the audio data
98
  audio, sr = sf.read(wav_io, dtype='float32')
99
  except Exception as e:
100
  logger.error(f"Failed to read audio data: {e}")
@@ -103,20 +106,14 @@ async def process_audio_data(audio_data):
103
  if audio.ndim > 1: # If audio has more than one channel, average them
104
  audio = np.mean(audio, axis=1)
105
 
106
- logger.info(f"The len of audio: {len(audio)}")
107
- logger.info("Extracting features")
108
  features = extract_features(audio)
109
  features = features.reshape(1, -1)
110
-
111
- try:
112
- prediction = model.predict(features)
113
- is_fake = prediction[0]
114
- result = 'fake' if is_fake else 'real'
115
- except Exception as e:
116
- logger.error(f"Model prediction failed: {e}")
117
- return
118
-
119
- await manager.send_message(result)
120
 
121
  @app.post("/start_detection")
122
  async def start_detection():
@@ -138,7 +135,8 @@ async def websocket_endpoint(websocket: WebSocket):
138
  try:
139
  while True:
140
  data = await websocket.receive_bytes()
141
- await process_audio_data(data)
 
142
  except WebSocketDisconnect:
143
  manager.disconnect(websocket)
144
 
 
13
  import io
14
  from pydub import AudioSegment
15
  from typing import List
16
+ import asyncio
17
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
 
class ConnectionManager:
    """Track accepted WebSocket connections and one audio buffer per connection.

    Each registered connection owns a ``bytes`` buffer in ``audio_buffers``.
    Incoming chunks are appended to it until the caller clears it.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # Bytes received so far, keyed by the connection they arrived on.
        self.audio_buffers = {}

    async def connect(self, websocket: WebSocket):
        """Accept the connection, register it and give it an empty buffer."""
        await websocket.accept()
        self.active_connections.append(websocket)
        self.audio_buffers[websocket] = b''

    def disconnect(self, websocket: WebSocket):
        """Unregister the connection and discard its buffer.

        Calling this for a connection that is not registered (never
        connected, or already disconnected) is a no-op, so a teardown path
        that runs more than once does not raise ValueError/KeyError.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        # pop() with a default tolerates a buffer that is already gone.
        self.audio_buffers.pop(websocket, None)

    async def send_message(self, websocket: WebSocket, message: str):
        """Send ``message`` as text to this one connection only."""
        await websocket.send_text(message)

    def add_to_buffer(self, websocket: WebSocket, data: bytes):
        """Append ``data`` to the connection's buffer.

        Raises:
            KeyError: if the connection has no buffer (it is not registered).
        """
        self.audio_buffers[websocket] += data

    def get_buffer(self, websocket: WebSocket) -> bytes:
        """Return everything buffered so far for the connection.

        Raises:
            KeyError: if the connection has no buffer (it is not registered).
        """
        return self.audio_buffers[websocket]

    def clear_buffer(self, websocket: WebSocket):
        """Reset the connection's buffer to empty."""
        self.audio_buffers[websocket] = b''
72
 
73
  manager = ConnectionManager()
74
 
 
90
  combined_features = np.hstack([mfccs, chroma, contrast, centroid])
91
  return combined_features
92
 
93
+ async def process_audio_data(websocket: WebSocket):
94
+ audio_data = manager.get_buffer(websocket)
95
  try:
96
+ # Convert audio data from webm/ogg to wav format using pydub
 
 
97
  audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="webm")
 
 
 
 
 
 
98
  wav_io = io.BytesIO()
99
  audio_segment.export(wav_io, format="wav")
100
  wav_io.seek(0)
 
 
101
  audio, sr = sf.read(wav_io, dtype='float32')
102
  except Exception as e:
103
  logger.error(f"Failed to read audio data: {e}")
 
106
  if audio.ndim > 1: # If audio has more than one channel, average them
107
  audio = np.mean(audio, axis=1)
108
 
 
 
109
  features = extract_features(audio)
110
  features = features.reshape(1, -1)
111
+ prediction = model.predict(features)
112
+ is_fake = prediction[0]
113
+ result = 'fake' if is_fake else 'real'
114
+
115
+ await manager.send_message(websocket, result)
116
+ manager.clear_buffer(websocket)
 
 
 
 
117
 
118
  @app.post("/start_detection")
119
  async def start_detection():
 
135
  try:
136
  while True:
137
  data = await websocket.receive_bytes()
138
+ manager.add_to_buffer(websocket, data)
139
+ await process_audio_data(websocket)
140
  except WebSocketDisconnect:
141
  manager.disconnect(websocket)
142