File size: 10,547 Bytes
3ead8a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# app.py
# File máy chủ Flask hoàn chỉnh để triển khai trên Hugging Face Spaces

import os
import joblib
import numpy as np
import librosa
from flask import Flask, request, jsonify, render_template
from werkzeug.utils import secure_filename

# --- Cấu hình TensorFlow và các thư viện AI ---
# Đặt biến môi trường để giảm thiểu log không cần thiết của TensorFlow.
# Phải thực hiện trước khi import tensorflow.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import tensorflow as tf
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch

# --- KHỞI TẠO ỨNG DỤNG FLASK ---
app = Flask(__name__)

# Cấu hình thư mục tạm để lưu file audio người dùng tải lên
UPLOAD_FOLDER = 'uploads/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Tạo thư mục nếu nó chưa tồn tại
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)


# --- TẢI TẤT CẢ CÁC MÔ HÌNH KHI SERVER KHỞI ĐỘNG ---
# Đây là bước quan trọng để tối ưu hóa hiệu suất. Mô hình chỉ được tải một lần
# thay vì tải lại mỗi khi có yêu cầu dự đoán.
print(">>> Đang tải các mô hình AI, quá trình này có thể mất một lúc...")

try:
    MODEL_PATH = 'models/'
    
    # Tải các thành phần tiền xử lý và các mô hình machine learning
    scaler = joblib.load(os.path.join(MODEL_PATH, 'scaler.pkl'))
    label_encoder = joblib.load(os.path.join(MODEL_PATH, 'label_encoder.pkl'))
    model_xgb = joblib.load(os.path.join(MODEL_PATH, 'xgboost.pkl'))
    model_lgb = joblib.load(os.path.join(MODEL_PATH, 'lightgbm.pkl'))
    
    # Tải mô hình deep learning (CNN)
    model_cnn = tf.keras.models.load_model(os.path.join(MODEL_PATH, 'cnn.keras'))
    
    # Tải mô hình xử lý ngôn ngữ tự nhiên cho audio (Wav2Vec2)
    wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    
    print(">>> OK! Tất cả các mô hình đã được tải thành công!")

except Exception as e:
    print(f"!!! LỖI NGHIÊM TRỌNG: Không thể tải một hoặc nhiều mô hình. Lỗi: {e}")
    print("!!! Vui lòng kiểm tra lại đường dẫn và sự tồn tại của các file trong thư mục 'models/'.")
    # Thoát ứng dụng nếu không tải được mô hình
    exit()

# --- CÁC HÀM TRÍCH XUẤT ĐẶC TRƯNG ---
# Các hàm này phải giống hệt với các hàm đã được sử dụng trong quá trình huấn luyện
# để đảm bảo tính nhất quán của dữ liệu đầu vào cho mô hình.

# Các hằng số cấu hình
SAMPLE_RATE = 22050
MAX_LENGTH_SECONDS = 5.0
MAX_SAMPLES = int(SAMPLE_RATE * MAX_LENGTH_SECONDS)
N_MELS = 128
TRADITIONAL_FEATURE_SIZE = 570 # (128*4 cho melspec + 13*2 cho mfcc + ...) - Phải khớp với lúc train
WAV2VEC_FEATURE_SIZE = 768
SPECTROGRAM_SHAPE = (224, 224, 3)

def _extract_traditional_features(y, sr):
    """Trích xuất các đặc trưng âm thanh truyền thống (MFCC, Mel Spectrogram, etc.)."""
    try:
        # Mel Spectrogram features (mean, std, max, min)
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        features = np.mean(mel_spec_db, axis=1)
        features = np.append(features, np.std(mel_spec_db, axis=1))
        features = np.append(features, np.max(mel_spec_db, axis=1))
        features = np.append(features, np.min(mel_spec_db, axis=1))
        
        # MFCC features (mean, std)
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
        features = np.append(features, np.mean(mfccs, axis=1))
        features = np.append(features, np.std(mfccs, axis=1))
        
        # Cần thêm các đặc trưng khác nếu có trong lúc train để đủ `TRADITIONAL_FEATURE_SIZE`
        # Ví dụ: chroma, spectral_contrast, etc.
        # Ở đây, chúng ta sẽ pad/truncate để đảm bảo kích thước
        if len(features) > TRADITIONAL_FEATURE_SIZE:
            features = features[:TRADITIONAL_FEATURE_SIZE]
        elif len(features) < TRADITIONAL_FEATURE_SIZE:
            features = np.pad(features, (0, TRADITIONAL_FEATURE_SIZE - len(features)), mode='constant')
            
        return features

    except Exception as e:
        print(f"Lỗi trích xuất đặc trưng truyền thống: {e}")
        return np.zeros(TRADITIONAL_FEATURE_SIZE)

def _extract_wav2vec_features(y, sr):
    """Trích xuất đặc trưng từ mô hình Wav2Vec2."""
    try:
        # Wav2Vec2 yêu cầu sample rate 16000
        y_16k = librosa.resample(y, orig_sr=sr, target_sr=16000)
        inputs = wav2vec_processor(y_16k, sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = wav2vec_model(**inputs)
        # Lấy trung bình các hidden states cuối cùng để có một vector đại diện
        features = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        return features
    except Exception as e:
        print(f"Lỗi trích xuất Wav2Vec2: {e}")
        return np.zeros(WAV2VEC_FEATURE_SIZE)

def _create_spectrogram_image(y, sr):
    """Tạo ảnh spectrogram cho mô hình CNN."""
    try:
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=SPECTROGRAM_SHAPE[0])
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        # Chuẩn hóa giá trị về khoảng [0, 255]
        mel_spec_norm = ((mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8) * 255).astype(np.uint8)
        # Chuyển thành ảnh 3 kênh (RGB)
        img = tf.keras.preprocessing.image.array_to_img(np.stack([mel_spec_norm]*3, axis=-1))
        # Resize về kích thước đầu vào của CNN
        img = img.resize((SPECTROGRAM_SHAPE[1], SPECTROGRAM_SHAPE[0]))
        return np.array(img)
    except Exception as e:
        print(f"Lỗi tạo ảnh spectrogram: {e}")
        return np.zeros(SPECTROGRAM_SHAPE)

def process_audio_file(file_path):
    """Hàm tổng hợp: Tải file audio và gọi các hàm trích xuất đặc trưng."""
    try:
        y, sr = librosa.load(file_path, sr=SAMPLE_RATE)
        
        # Chuẩn hóa độ dài audio về MAX_SAMPLES
        if len(y) > MAX_SAMPLES:
            y = y[:MAX_SAMPLES]
        else:
            y = np.pad(y, (0, MAX_SAMPLES - len(y)), mode='constant')

        # Trích xuất đồng thời các bộ đặc trưng
        traditional_features = _extract_traditional_features(y, sr)
        wav2vec_features = _extract_wav2vec_features(y, sr)
        spectrogram = _create_spectrogram_image(y, sr)

        return traditional_features, wav2vec_features, spectrogram
    except Exception as e:
        print(f"Lỗi nghiêm trọng khi xử lý file audio {file_path}: {e}")
        return None, None, None


# --- ĐỊNH NGHĨA CÁC ROUTE (API ENDPOINTS) CỦA ỨNG DỤNG ---
@app.route('/', methods=['GET'])
def home():
    """Render trang chủ của ứng dụng."""
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    """API endpoint để nhận file audio, xử lý và trả về kết quả dự đoán."""
    if 'audio_file' not in request.files:
        return jsonify({'error': 'Không có file audio nào trong yêu cầu.'}), 400

    file = request.files['audio_file']
    if file.filename == '':
        return jsonify({'error': 'Tên file không hợp lệ.'}), 400

    try:
        # Lưu file audio vào thư mục tạm một cách an toàn
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Xử lý file audio để trích xuất tất cả các đặc trưng cần thiết
        trad_feats, w2v_feats, spec_img = process_audio_file(filepath)

        if trad_feats is None:
             return jsonify({'error': 'Không thể xử lý file audio.'}), 500

        # --- Chuẩn bị dữ liệu đầu vào cho từng mô hình ---
        # 1. Dữ liệu cho XGBoost và LightGBM (kết hợp và scale)
        combined_feats = np.concatenate([trad_feats, w2v_feats]).reshape(1, -1)
        scaled_feats = scaler.transform(combined_feats)
        
        # 2. Dữ liệu cho CNN (chuẩn hóa và thêm chiều batch)
        spec_img = spec_img / 255.0
        spec_img = np.expand_dims(spec_img, axis=0)

        # --- Lấy dự đoán từ tất cả các mô hình ---
        pred_xgb = model_xgb.predict_proba(scaled_feats)[0][1]
        pred_lgb = model_lgb.predict_proba(scaled_feats)[0][1]
        pred_cnn = model_cnn.predict(spec_img, verbose=0)[0][0]

        # --- Ensemble: Kết hợp kết quả bằng cách lấy trung bình xác suất ---
        final_prediction_prob = (pred_xgb + pred_lgb + pred_cnn) / 3
        # Quyết định nhãn cuối cùng dựa trên ngưỡng 0.5
        final_prediction_label_index = 1 if final_prediction_prob > 0.5 else 0

        # Chuyển đổi chỉ số nhãn (0 hoặc 1) thành chuỗi ('male'/'female')
        result_label_text = label_encoder.inverse_transform([final_prediction_label_index])[0]

        # Xóa file audio tạm sau khi xử lý xong
        os.remove(filepath)
        
        print(f"Phân tích hoàn tất. Kết quả: {result_label_text.upper()} (Xác suất: {final_prediction_prob:.2f})")

        # Trả về kết quả dưới dạng JSON
        return jsonify({
            'prediction': result_label_text.capitalize(),
            'probability': f"{final_prediction_prob:.2f}"
        })

    except Exception as e:
        print(f"Đã xảy ra lỗi trong quá trình dự đoán: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': 'Đã xảy ra lỗi không xác định trên máy chủ.'}), 500

# --- ĐIỂM BẮT ĐẦU CHẠY ỨNG DỤNG ---
# Đoạn mã này được cấu hình để hoạt động tốt trên cả máy local và Hugging Face Spaces.
if __name__ == '__main__':
    # Hugging Face Spaces sẽ đặt biến môi trường PORT. Nếu không có, dùng 7860 làm mặc định.
    port = int(os.environ.get("PORT", 7860))
    # Chạy trên host '0.0.0.0' để ứng dụng có thể được truy cập từ bên ngoài container Docker.
    app.run(host='0.0.0.0', port=port)