kowalsky commited on
Commit
4bde5af
·
1 Parent(s): 7e42d52
Files changed (4) hide show
  1. Dockerfile +2 -2
  2. app.py +148 -0
  3. models/xgb_test.pkl +3 -0
  4. template/index.html +121 -0
Dockerfile CHANGED
@@ -13,6 +13,6 @@ RUN pip install --no-cache-dir -r requirements.txt
13
 
14
  COPY . .
15
 
16
- EXPOSE 8000
17
 
18
- CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
 
13
 
14
  COPY . .
15
 
16
+ EXPOSE 7860
17
 
18
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse, HTMLResponse
4
+ import sounddevice as sd
5
+ import numpy as np
6
+ import librosa
7
+ import joblib
8
+ import uvicorn
9
+ import threading
10
+ import asyncio
11
+ import logging
12
+ from typing import List
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ app = FastAPI()
18
+
19
+ @app.get("/", response_class=HTMLResponse)
20
+ async def get(request: Request):
21
+ logger.info("Saving the index page")
22
+ with open("templates/index.html") as f:
23
+ html_content = f.read()
24
+ return HTMLResponse(content=html_content, status_code=200)
25
+
26
+ @app.get("/health")
27
+ def health_check():
28
+ return {"status": "ok"}
29
+
30
+ # @app.get("/")
31
+ # def read_root():
32
+ # return {"status": "ok"}
33
+
34
+ app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"],
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
+
42
+ duration = 2
43
+ sample_rate = 16000
44
+
45
+ is_detecting = False
46
+ detection_thread = None
47
+
48
+ model = joblib.load('models/xgb_test.pkl')
49
+
50
+ class ConnectionManager:
51
+ def __init__(self):
52
+ self.active_connections: List[WebSocket] = []
53
+
54
+ async def connect(self, websocket: WebSocket):
55
+ await websocket.accept()
56
+ self.active_connections.append(websocket)
57
+
58
+ def disconnect(self, websocket: WebSocket):
59
+ self.active_connections.remove(websocket)
60
+
61
+ async def send_message(self, message: str):
62
+ for connection in self.active_connections:
63
+ await connection.send_text(message)
64
+
65
+ manager = ConnectionManager()
66
+
67
+ def extract_features(audio):
68
+ sr = 16000
69
+
70
+ mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
71
+ mfccs = np.mean(mfccs, axis=1)
72
+
73
+ chroma = librosa.feature.chroma_stft(y=audio, sr=sr)
74
+ chroma = np.mean(chroma, axis=1)
75
+
76
+ contrast = librosa.feature.spectral_contrast(y=audio, sr=sr)
77
+ contrast = np.mean(contrast, axis=1)
78
+
79
+ centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
80
+ centroid = np.mean(centroid, axis=1)
81
+
82
+ combined_features = np.hstack([mfccs, chroma, contrast, centroid])
83
+ return combined_features
84
+
85
+ async def audio_callback(indata, frames, time, status):
86
+ if status:
87
+ print(status)
88
+
89
+ audio_data = indata[:, 0]
90
+ print(f"Audio data: {audio_data[:10]}... (length: {len(audio_data)})")
91
+ logger.info(f"Audio data: {audio_data[:10]}... (length: {len(audio_data)})")
92
+
93
+ features = extract_features(audio_data)
94
+ features = features.reshape(1, -1)
95
+ prediction = model.predict(features)
96
+ is_fake = prediction[0]
97
+
98
+ print(f"Prediction: {is_fake}")
99
+ logger.info(f"Prediction: {is_fake}")
100
+
101
+ result = 'fake' if is_fake else 'real'
102
+
103
+ print(f"Detected {result} audio")
104
+ logger.info(f"Detected {result} audio")
105
+
106
+ await manager.send_message(result)
107
+
108
+ def detect_fake_audio():
109
+ global is_detecting
110
+ try:
111
+ with sd.InputStream(callback=lambda indata, frames, time, status: asyncio.run(audio_callback(indata, frames, time, status)), channels=1, samplerate=sample_rate, blocksize=int(sample_rate * duration)):
112
+ print("Listening...")
113
+ logger.info("Listening...")
114
+ while is_detecting:
115
+ sd.sleep(duration * 1000)
116
+ except Exception as e:
117
+ print(f"Exception: {str(e)}")
118
+ logger.info(f"Exception: {str(e)}")
119
+
120
+ @app.post("/start_detection")
121
+ async def start_detection():
122
+ global is_detecting, detection_thread
123
+
124
+ if not is_detecting:
125
+ is_detecting = True
126
+ detection_thread = threading.Thread(target=detect_fake_audio)
127
+ detection_thread.start()
128
+ return JSONResponse(content={'status': 'detection_started'})
129
+
130
+ @app.post("/stop_detection")
131
+ async def stop_detection():
132
+ global is_detecting, detection_thread
133
+ is_detecting = False
134
+ if detection_thread:
135
+ detection_thread.join()
136
+ return JSONResponse(content={'status': 'detection_stopped'})
137
+
138
+ @app.websocket("/ws")
139
+ async def websocket_endpoint(websocket: WebSocket):
140
+ await manager.connect(websocket)
141
+ try:
142
+ while True:
143
+ await websocket.receive_text()
144
+ except WebSocketDisconnect:
145
+ manager.disconnect(websocket)
146
+
147
+ if __name__ == '__main__':
148
+ uvicorn.run(app, host="0.0.0.0", port=7860)
models/xgb_test.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:606852a5fc31ad652bc03333af296a578a0c45cbec15a159e348045533805ab9
3
+ size 113963
template/index.html CHANGED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Fake Audio Detection</title>
7
+ <style>
8
+ body {
9
+ font-family: Arial, sans-serif;
10
+ margin: 20px;
11
+ }
12
+ #console {
13
+ border: 1px solid #ddd;
14
+ padding: 10px;
15
+ height: 200px;
16
+ overflow-y: scroll;
17
+ }
18
+ .log {
19
+ margin: 0;
20
+ padding: 5px;
21
+ }
22
+ .log.real {
23
+ color: green;
24
+ }
25
+ .log.fake {
26
+ color: red;
27
+ }
28
+ </style>
29
+ </head>
30
+ <body>
31
+ <h1>Fake Audio Detection</h1>
32
+ <button id="startButton">Start Detection</button>
33
+ <button id="stopButton" disabled>Stop Detection</button>
34
+ <div id="console"></div>
35
+
36
+ <script>
37
+ let isDetecting = false;
38
+ const startButton = document.getElementById('startButton');
39
+ const stopButton = document.getElementById('stopButton');
40
+ const consoleDiv = document.getElementById('console');
41
+ let websocket = null;
42
+
43
+ function logMessage(message, type) {
44
+ const log = document.createElement('p');
45
+ log.className = `log ${type}`;
46
+ log.textContent = message;
47
+ consoleDiv.appendChild(log);
48
+ consoleDiv.scrollTop = consoleDiv.scrollHeight;
49
+ }
50
+
51
+ async function startDetection() {
52
+ if (isDetecting) {
53
+ logMessage('Detection is already running...', 'info');
54
+ return;
55
+ }
56
+ isDetecting = true;
57
+ logMessage('Starting detection...', 'info');
58
+ startButton.disabled = true;
59
+ stopButton.disabled = false;
60
+
61
+ try {
62
+ const response = await fetch('http://localhost:7860/start_detection', {
63
+ method: 'POST',
64
+ });
65
+
66
+ if (!response.ok) {
67
+ throw new Error('Network response was not ok');
68
+ }
69
+
70
+ const result = await response.json();
71
+ logMessage(`Detection started: ${result.status}`, 'info');
72
+
73
+ websocket = new WebSocket('ws://localhost:7860/ws');
74
+ websocket.onmessage = function(event) {
75
+ const data = event.data;
76
+ logMessage(`Detected ${data} audio`, data);
77
+ };
78
+ websocket.onclose = function() {
79
+ logMessage('WebSocket connection closed', 'info');
80
+ };
81
+ } catch (error) {
82
+ logMessage(`Error: ${error.message}`, 'error');
83
+ }
84
+ }
85
+
86
+ async function stopDetection() {
87
+ if (!isDetecting) {
88
+ logMessage('Detection is not running...', 'info');
89
+ return;
90
+ }
91
+ isDetecting = false;
92
+ logMessage('Stopping detection...', 'info');
93
+ startButton.disabled = false;
94
+ stopButton.disabled = true;
95
+
96
+ try {
97
+ const response = await fetch('http://localhost:7860/stop_detection', {
98
+ method: 'POST',
99
+ });
100
+
101
+ if (!response.ok) {
102
+ throw new Error('Network response was not ok');
103
+ }
104
+
105
+ const result = await response.json();
106
+ logMessage(`Detection stopped: ${result.status}`, 'info');
107
+
108
+ if (websocket) {
109
+ websocket.close();
110
+ websocket = null;
111
+ }
112
+ } catch (error) {
113
+ logMessage(`Error: ${error.message}`, 'error');
114
+ }
115
+ }
116
+
117
+ startButton.addEventListener('click', startDetection);
118
+ stopButton.addEventListener('click', stopDetection);
119
+ </script>
120
+ </body>
121
+ </html>