3v324v23 commited on
Commit
3ead8a3
·
1 Parent(s): ad71b2b

Add application file

Browse files
Files changed (5) hide show
  1. app.py +228 -0
  2. requirements.txt +13 -0
  3. static/script.js +73 -0
  4. static/style.css +84 -0
  5. templates/index.html +25 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # File máy chủ Flask hoàn chỉnh để triển khai trên Hugging Face Spaces
3
+
4
+ import os
5
+ import joblib
6
+ import numpy as np
7
+ import librosa
8
+ from flask import Flask, request, jsonify, render_template
9
+ from werkzeug.utils import secure_filename
10
+
11
+ # --- Cấu hình TensorFlow và các thư viện AI ---
12
+ # Đặt biến môi trường để giảm thiểu log không cần thiết của TensorFlow.
13
+ # Phải thực hiện trước khi import tensorflow.
14
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
15
+ import tensorflow as tf
16
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
17
+ import torch
18
+
19
+ # --- KHỞI TẠO ỨNG DỤNG FLASK ---
20
+ app = Flask(__name__)
21
+
22
+ # Cấu hình thư mục tạm để lưu file audio người dùng tải lên
23
+ UPLOAD_FOLDER = 'uploads/'
24
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
25
+ # Tạo thư mục nếu nó chưa tồn tại
26
+ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
27
+
28
+
29
+ # --- TẢI TẤT CẢ CÁC MÔ HÌNH KHI SERVER KHỞI ĐỘNG ---
30
+ # Đây là bước quan trọng để tối ưu hóa hiệu suất. Mô hình chỉ được tải một lần
31
+ # thay vì tải lại mỗi khi có yêu cầu dự đoán.
32
+ print(">>> Đang tải các mô hình AI, quá trình này có thể mất một lúc...")
33
+
34
+ try:
35
+ MODEL_PATH = 'models/'
36
+
37
+ # Tải các thành phần tiền xử lý và các mô hình machine learning
38
+ scaler = joblib.load(os.path.join(MODEL_PATH, 'scaler.pkl'))
39
+ label_encoder = joblib.load(os.path.join(MODEL_PATH, 'label_encoder.pkl'))
40
+ model_xgb = joblib.load(os.path.join(MODEL_PATH, 'xgboost.pkl'))
41
+ model_lgb = joblib.load(os.path.join(MODEL_PATH, 'lightgbm.pkl'))
42
+
43
+ # Tải mô hình deep learning (CNN)
44
+ model_cnn = tf.keras.models.load_model(os.path.join(MODEL_PATH, 'cnn.keras'))
45
+
46
+ # Tải mô hình xử lý ngôn ngữ tự nhiên cho audio (Wav2Vec2)
47
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
48
+ wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
49
+
50
+ print(">>> OK! Tất cả các mô hình đã được tải thành công!")
51
+
52
+ except Exception as e:
53
+ print(f"!!! LỖI NGHIÊM TRỌNG: Không thể tải một hoặc nhiều mô hình. Lỗi: {e}")
54
+ print("!!! Vui lòng kiểm tra lại đường dẫn và sự tồn tại của các file trong thư mục 'models/'.")
55
+ # Thoát ứng dụng nếu không tải được mô hình
56
+ exit()
57
+
58
+ # --- CÁC HÀM TRÍCH XUẤT ĐẶC TRƯNG ---
59
+ # Các hàm này phải giống hệt với các hàm đã được sử dụng trong quá trình huấn luyện
60
+ # để đảm bảo tính nhất quán của dữ liệu đầu vào cho mô hình.
61
+
62
+ # Các hằng số cấu hình
63
+ SAMPLE_RATE = 22050
64
+ MAX_LENGTH_SECONDS = 5.0
65
+ MAX_SAMPLES = int(SAMPLE_RATE * MAX_LENGTH_SECONDS)
66
+ N_MELS = 128
67
+ TRADITIONAL_FEATURE_SIZE = 570 # (128*4 cho melspec + 13*2 cho mfcc + ...) - Phải khớp với lúc train
68
+ WAV2VEC_FEATURE_SIZE = 768
69
+ SPECTROGRAM_SHAPE = (224, 224, 3)
70
+
71
+ def _extract_traditional_features(y, sr):
72
+ """Trích xuất các đặc trưng âm thanh truyền thống (MFCC, Mel Spectrogram, etc.)."""
73
+ try:
74
+ # Mel Spectrogram features (mean, std, max, min)
75
+ mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS)
76
+ mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
77
+ features = np.mean(mel_spec_db, axis=1)
78
+ features = np.append(features, np.std(mel_spec_db, axis=1))
79
+ features = np.append(features, np.max(mel_spec_db, axis=1))
80
+ features = np.append(features, np.min(mel_spec_db, axis=1))
81
+
82
+ # MFCC features (mean, std)
83
+ mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
84
+ features = np.append(features, np.mean(mfccs, axis=1))
85
+ features = np.append(features, np.std(mfccs, axis=1))
86
+
87
+ # Cần thêm các đặc trưng khác nếu có trong lúc train để đủ `TRADITIONAL_FEATURE_SIZE`
88
+ # Ví dụ: chroma, spectral_contrast, etc.
89
+ # Ở đây, chúng ta sẽ pad/truncate để đảm bảo kích thước
90
+ if len(features) > TRADITIONAL_FEATURE_SIZE:
91
+ features = features[:TRADITIONAL_FEATURE_SIZE]
92
+ elif len(features) < TRADITIONAL_FEATURE_SIZE:
93
+ features = np.pad(features, (0, TRADITIONAL_FEATURE_SIZE - len(features)), mode='constant')
94
+
95
+ return features
96
+
97
+ except Exception as e:
98
+ print(f"Lỗi trích xuất đặc trưng truyền thống: {e}")
99
+ return np.zeros(TRADITIONAL_FEATURE_SIZE)
100
+
101
+ def _extract_wav2vec_features(y, sr):
102
+ """Trích xuất đặc trưng từ mô hình Wav2Vec2."""
103
+ try:
104
+ # Wav2Vec2 yêu cầu sample rate 16000
105
+ y_16k = librosa.resample(y, orig_sr=sr, target_sr=16000)
106
+ inputs = wav2vec_processor(y_16k, sampling_rate=16000, return_tensors="pt", padding=True)
107
+ with torch.no_grad():
108
+ outputs = wav2vec_model(**inputs)
109
+ # Lấy trung bình các hidden states cuối cùng để có một vector đại diện
110
+ features = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
111
+ return features
112
+ except Exception as e:
113
+ print(f"Lỗi trích xuất Wav2Vec2: {e}")
114
+ return np.zeros(WAV2VEC_FEATURE_SIZE)
115
+
116
+ def _create_spectrogram_image(y, sr):
117
+ """Tạo ảnh spectrogram cho mô hình CNN."""
118
+ try:
119
+ mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=SPECTROGRAM_SHAPE[0])
120
+ mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
121
+ # Chuẩn hóa giá trị về khoảng [0, 255]
122
+ mel_spec_norm = ((mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8) * 255).astype(np.uint8)
123
+ # Chuyển thành ảnh 3 kênh (RGB)
124
+ img = tf.keras.preprocessing.image.array_to_img(np.stack([mel_spec_norm]*3, axis=-1))
125
+ # Resize về kích thước đầu vào của CNN
126
+ img = img.resize((SPECTROGRAM_SHAPE[1], SPECTROGRAM_SHAPE[0]))
127
+ return np.array(img)
128
+ except Exception as e:
129
+ print(f"Lỗi tạo ảnh spectrogram: {e}")
130
+ return np.zeros(SPECTROGRAM_SHAPE)
131
+
132
+ def process_audio_file(file_path):
133
+ """Hàm tổng hợp: Tải file audio và gọi các hàm trích xuất đặc trưng."""
134
+ try:
135
+ y, sr = librosa.load(file_path, sr=SAMPLE_RATE)
136
+
137
+ # Chuẩn hóa độ dài audio về MAX_SAMPLES
138
+ if len(y) > MAX_SAMPLES:
139
+ y = y[:MAX_SAMPLES]
140
+ else:
141
+ y = np.pad(y, (0, MAX_SAMPLES - len(y)), mode='constant')
142
+
143
+ # Trích xuất đồng thời các bộ đặc trưng
144
+ traditional_features = _extract_traditional_features(y, sr)
145
+ wav2vec_features = _extract_wav2vec_features(y, sr)
146
+ spectrogram = _create_spectrogram_image(y, sr)
147
+
148
+ return traditional_features, wav2vec_features, spectrogram
149
+ except Exception as e:
150
+ print(f"Lỗi nghiêm trọng khi xử lý file audio {file_path}: {e}")
151
+ return None, None, None
152
+
153
+
154
+ # --- ĐỊNH NGHĨA CÁC ROUTE (API ENDPOINTS) CỦA ỨNG DỤNG ---
155
+ @app.route('/', methods=['GET'])
156
+ def home():
157
+ """Render trang chủ của ứng dụng."""
158
+ return render_template('index.html')
159
+
160
+ @app.route('/predict', methods=['POST'])
161
+ def predict():
162
+ """API endpoint để nhận file audio, xử lý và trả về kết quả dự đoán."""
163
+ if 'audio_file' not in request.files:
164
+ return jsonify({'error': 'Không có file audio nào trong yêu cầu.'}), 400
165
+
166
+ file = request.files['audio_file']
167
+ if file.filename == '':
168
+ return jsonify({'error': 'Tên file không hợp lệ.'}), 400
169
+
170
+ try:
171
+ # Lưu file audio vào thư mục tạm một cách an toàn
172
+ filename = secure_filename(file.filename)
173
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
174
+ file.save(filepath)
175
+
176
+ # Xử lý file audio để trích xuất tất cả các đặc trưng cần thiết
177
+ trad_feats, w2v_feats, spec_img = process_audio_file(filepath)
178
+
179
+ if trad_feats is None:
180
+ return jsonify({'error': 'Không thể xử lý file audio.'}), 500
181
+
182
+ # --- Chuẩn bị dữ liệu đầu vào cho từng mô hình ---
183
+ # 1. Dữ liệu cho XGBoost và LightGBM (kết hợp và scale)
184
+ combined_feats = np.concatenate([trad_feats, w2v_feats]).reshape(1, -1)
185
+ scaled_feats = scaler.transform(combined_feats)
186
+
187
+ # 2. Dữ liệu cho CNN (chuẩn hóa và thêm chiều batch)
188
+ spec_img = spec_img / 255.0
189
+ spec_img = np.expand_dims(spec_img, axis=0)
190
+
191
+ # --- Lấy dự đoán từ tất cả các mô hình ---
192
+ pred_xgb = model_xgb.predict_proba(scaled_feats)[0][1]
193
+ pred_lgb = model_lgb.predict_proba(scaled_feats)[0][1]
194
+ pred_cnn = model_cnn.predict(spec_img, verbose=0)[0][0]
195
+
196
+ # --- Ensemble: Kết hợp kết quả bằng cách lấy trung bình xác suất ---
197
+ final_prediction_prob = (pred_xgb + pred_lgb + pred_cnn) / 3
198
+ # Quyết định nhãn cuối cùng dựa trên ngưỡng 0.5
199
+ final_prediction_label_index = 1 if final_prediction_prob > 0.5 else 0
200
+
201
+ # Chuyển đổi chỉ số nhãn (0 hoặc 1) thành chuỗi ('male'/'female')
202
+ result_label_text = label_encoder.inverse_transform([final_prediction_label_index])[0]
203
+
204
+ # Xóa file audio tạm sau khi xử lý xong
205
+ os.remove(filepath)
206
+
207
+ print(f"Phân tích hoàn tất. Kết quả: {result_label_text.upper()} (Xác suất: {final_prediction_prob:.2f})")
208
+
209
+ # Trả về kết quả dưới dạng JSON
210
+ return jsonify({
211
+ 'prediction': result_label_text.capitalize(),
212
+ 'probability': f"{final_prediction_prob:.2f}"
213
+ })
214
+
215
+ except Exception as e:
216
+ print(f"Đã xảy ra lỗi trong quá trình dự đoán: {e}")
217
+ import traceback
218
+ traceback.print_exc()
219
+ return jsonify({'error': 'Đã xảy ra lỗi không xác ��ịnh trên máy chủ.'}), 500
220
+
221
+ # --- ĐIỂM BẮT ĐẦU CHẠY ỨNG DỤNG ---
222
+ # Đoạn mã này được cấu hình để hoạt động tốt trên cả máy local và Hugging Face Spaces.
223
+ if __name__ == '__main__':
224
+ # Hugging Face Spaces sẽ đặt biến môi trường PORT. Nếu không có, dùng 7860 làm mặc định.
225
+ port = int(os.environ.get("PORT", 7860))
226
+ # Chạy trên host '0.0.0.0' để ứng dụng có thể được truy cập từ bên ngoài container Docker.
227
+ app.run(host='0.0.0.0', port=port)
228
+
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+
3
+ Flask
4
+ gunicorn
5
+ tensorflow==2.15.0 # Ghi rõ phiên bản để đảm bảo tương thích
6
+ torch
7
+ transformers
8
+ joblib
9
+ scikit-learn
10
+ xgboost
11
+ lightgbm
12
+ librosa
13
+ numpy
static/script.js ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // static/script.js
2
+ document.addEventListener('DOMContentLoaded', () => {
3
+ const recordButton = document.getElementById('recordButton');
4
+ const stopButton = document.getElementById('stopButton');
5
+ const statusDiv = document.getElementById('status');
6
+ const resultDiv = document.getElementById('result');
7
+
8
+ let mediaRecorder;
9
+ let audioChunks = [];
10
+
11
+ // --- Bắt đầu ghi âm ---
12
+ recordButton.addEventListener('click', async () => {
13
+ try {
14
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
15
+
16
+ mediaRecorder = new MediaRecorder(stream);
17
+
18
+ mediaRecorder.ondataavailable = event => {
19
+ audioChunks.push(event.data);
20
+ };
21
+
22
+ mediaRecorder.onstop = () => {
23
+ const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
24
+ sendAudioToServer(audioBlob);
25
+ audioChunks = [];
26
+ };
27
+
28
+ mediaRecorder.start();
29
+ statusDiv.textContent = '🔴 Đang ghi âm...';
30
+ recordButton.disabled = true;
31
+ stopButton.disabled = false;
32
+
33
+ } catch (error) {
34
+ console.error('Lỗi khi truy cập micro:', error);
35
+ statusDiv.textContent = 'Lỗi: Không thể truy cập micro.';
36
+ }
37
+ });
38
+
39
+ // --- Dừng ghi âm ---
40
+ stopButton.addEventListener('click', () => {
41
+ mediaRecorder.stop();
42
+ statusDiv.textContent = 'Đang xử lý... Vui lòng chờ.';
43
+ recordButton.disabled = false;
44
+ stopButton.disabled = true;
45
+ });
46
+
47
+ // --- Gửi audio lên server ---
48
+ async function sendAudioToServer(audioBlob) {
49
+ const formData = new FormData();
50
+ formData.append('audio_file', audioBlob, 'recording.wav');
51
+
52
+ try {
53
+ resultDiv.textContent = '...'; // Reset kết quả
54
+ const response = await fetch('/predict', {
55
+ method: 'POST',
56
+ body: formData,
57
+ });
58
+
59
+ if (response.ok) {
60
+ const data = await response.json();
61
+ resultDiv.textContent = data.prediction;
62
+ statusDiv.textContent = 'Hoàn thành! Sẵn sàng ghi âm lần nữa.';
63
+ } else {
64
+ const errorData = await response.json();
65
+ throw new Error(errorData.error || 'Lỗi không xác định từ server.');
66
+ }
67
+ } catch (error) {
68
+ console.error('Lỗi khi gửi audio:', error);
69
+ statusDiv.textContent = 'Đã xảy ra lỗi khi gửi audio.';
70
+ resultDiv.textContent = 'Lỗi!';
71
+ }
72
+ }
73
+ });
static/style.css ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* static/style.css */
2
+ body {
3
+ font-family: Arial, sans-serif;
4
+ background-color: #f0f8ff;
5
+ display: flex;
6
+ justify-content: center;
7
+ align-items: center;
8
+ height: 100vh;
9
+ margin: 0;
10
+ color: #333;
11
+ }
12
+
13
+ .container {
14
+ background-color: white;
15
+ padding: 40px;
16
+ border-radius: 12px;
17
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
18
+ text-align: center;
19
+ max-width: 500px;
20
+ }
21
+
22
+ h1 {
23
+ color: #2c3e50;
24
+ margin-bottom: 10px;
25
+ }
26
+
27
+ p {
28
+ color: #7f8c8d;
29
+ margin-bottom: 30px;
30
+ }
31
+
32
+ .controls button {
33
+ padding: 15px 30px;
34
+ border: none;
35
+ border-radius: 8px;
36
+ font-size: 16px;
37
+ cursor: pointer;
38
+ margin: 0 10px;
39
+ transition: all 0.3s ease;
40
+ }
41
+
42
+ #recordButton {
43
+ background-color: #e74c3c;
44
+ color: white;
45
+ }
46
+
47
+ #recordButton:hover {
48
+ background-color: #c0392b;
49
+ }
50
+
51
+ #stopButton {
52
+ background-color: #3498db;
53
+ color: white;
54
+ }
55
+
56
+ #stopButton:disabled {
57
+ background-color: #bdc3c7;
58
+ cursor: not-allowed;
59
+ }
60
+
61
+ #status {
62
+ margin-top: 20px;
63
+ font-style: italic;
64
+ color: #95a5a6;
65
+ }
66
+
67
+ .result-container {
68
+ margin-top: 30px;
69
+ padding: 20px;
70
+ background-color: #ecf0f1;
71
+ border-radius: 8px;
72
+ }
73
+
74
+ h2 {
75
+ color: #2c3e50;
76
+ margin-top: 0;
77
+ }
78
+
79
+ #result {
80
+ font-size: 24px;
81
+ font-weight: bold;
82
+ color: #2980b9;
83
+ min-height: 30px;
84
+ }
templates/index.html ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="vi">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Phân loại Giọng chim Đực/Cái</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
8
+ </head>
9
+ <body>
10
+ <div class="container">
11
+ <h1>Nhận dạng Giọng chim Đực/Cái</h1>
12
+ <p>Bấm nút "Ghi âm", tạo ra âm thanh của chim và bấm "Dừng" để xem kết quả.</p>
13
+ <div class="controls">
14
+ <button id="recordButton">Ghi âm</button>
15
+ <button id="stopButton" disabled>Dừng</button>
16
+ </div>
17
+ <div id="status">Sẵn sàng ghi âm...</div>
18
+ <div class="result-container">
19
+ <h2>Kết quả:</h2>
20
+ <div id="result">---</div>
21
+ </div>
22
+ </div>
23
+ <script src="{{ url_for('static', filename='script.js') }}"></script>
24
+ </body>
25
+ </html>