bird-classifier / app.py
3v324v23's picture
Force environment rebuild
a870ca2
raw
history blame
8.64 kB
# app.py
# Phiên bản hoàn chỉnh, đã sửa lỗi đọc file audio và đồng bộ hóa phiên bản thư viện.
import os
import joblib
import numpy as np
import librosa
from flask import Flask, request, jsonify, render_template
from werkzeug.utils import secure_filename
import traceback
# --- Thư viện mới để đọc audio một cách mạnh mẽ ---
from pydub import AudioSegment
# --- Cấu hình TensorFlow và các thư viện AI ---
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch
# --- KHỞI TẠO ỨNG DỤNG FLASK ---
app = Flask(__name__)
# Cấu hình thư mục tạm để lưu file audio
UPLOAD_FOLDER = 'uploads/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# --- TẢI TẤT CẢ CÁC MÔ HÌNH KHI SERVER KHỞI ĐỘNG ---
print(">>> Đang tải các mô hình AI, quá trình này có thể mất một lúc...")
try:
MODEL_PATH = 'models/'
scaler = joblib.load(os.path.join(MODEL_PATH, 'scaler.pkl'))
label_encoder = joblib.load(os.path.join(MODEL_PATH, 'label_encoder.pkl'))
model_xgb = joblib.load(os.path.join(MODEL_PATH, 'xgboost.pkl'))
model_lgb = joblib.load(os.path.join(MODEL_PATH, 'lightgbm.pkl'))
model_cnn = tf.keras.models.load_model(os.path.join(MODEL_PATH, 'cnn.keras'))
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
print(">>> OK! Tất cả các mô hình đã được tải thành công!")
except Exception as e:
print(f"!!! LỖI NGHIÊM TRỌNG: Không thể tải một hoặc nhiều mô hình. Lỗi: {e}")
traceback.print_exc()
exit()
# --- CÁC HÀM TRÍCH XUẤT ĐẶC TRƯNG (KHÔNG ĐỔI) ---
# Các hằng số cấu hình
SAMPLE_RATE = 22050
MAX_LENGTH_SECONDS = 5.0
MAX_SAMPLES = int(SAMPLE_RATE * MAX_LENGTH_SECONDS)
N_MELS = 128
TRADITIONAL_FEATURE_SIZE = 570
WAV2VEC_FEATURE_SIZE = 768
SPECTROGRAM_SHAPE = (224, 224, 3)
def _extract_traditional_features(y, sr):
try:
mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS)
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
features = np.mean(mel_spec_db, axis=1)
features = np.append(features, np.std(mel_spec_db, axis=1))
features = np.append(features, np.max(mel_spec_db, axis=1))
features = np.append(features, np.min(mel_spec_db, axis=1))
mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
features = np.append(features, np.mean(mfccs, axis=1))
features = np.append(features, np.std(mfccs, axis=1))
if len(features) > TRADITIONAL_FEATURE_SIZE:
features = features[:TRADITIONAL_FEATURE_SIZE]
elif len(features) < TRADITIONAL_FEATURE_SIZE:
features = np.pad(features, (0, TRADITIONAL_FEATURE_SIZE - len(features)), mode='constant')
return features
except Exception as e:
print(f"Lỗi trích xuất đặc trưng truyền thống: {e}")
return np.zeros(TRADITIONAL_FEATURE_SIZE)
def _extract_wav2vec_features(y, sr):
try:
y_16k = librosa.resample(y, orig_sr=sr, target_sr=16000)
inputs = wav2vec_processor(y_16k, sampling_rate=16000, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = wav2vec_model(**inputs)
features = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
return features
except Exception as e:
print(f"Lỗi trích xuất Wav2Vec2: {e}")
return np.zeros(WAV2VEC_FEATURE_SIZE)
def _create_spectrogram_image(y, sr):
try:
mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=SPECTROGRAM_SHAPE[0])
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
mel_spec_norm = ((mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8) * 255).astype(np.uint8)
img = tf.keras.preprocessing.image.array_to_img(np.stack([mel_spec_norm]*3, axis=-1))
img = img.resize((SPECTROGRAM_SHAPE[1], SPECTROGRAM_SHAPE[0]))
return np.array(img)
except Exception as e:
print(f"Lỗi tạo ảnh spectrogram: {e}")
return np.zeros(SPECTROGRAM_SHAPE)
# --- HÀM XỬ LÝ AUDIO ĐÃ ĐƯỢC CẬP NHẬT ---
def process_audio_file(file_path):
"""
Hàm tổng hợp phiên bản mới: Dùng pydub để đọc file audio một cách mạnh mẽ,
sau đó chuyển đổi sang định dạng mà librosa có thể xử lý an toàn.
"""
try:
# 1. Dùng pydub để mở file audio (hỗ trợ nhiều định dạng)
audio = AudioSegment.from_file(file_path)
# 2. Đảm bảo audio là mono (1 kênh) và có sample rate đúng
audio = audio.set_channels(1)
audio = audio.set_frame_rate(SAMPLE_RATE)
# 3. Chuyển đổi audio của pydub thành mảng NumPy cho librosa
# Chuẩn hóa về khoảng [-1, 1]
samples = np.array(audio.get_array_of_samples()).astype(np.float32)
y = samples / (2**(audio.sample_width * 8 - 1))
# 4. Chuẩn hóa độ dài audio về MAX_SAMPLES
if len(y) > MAX_SAMPLES:
y = y[:MAX_SAMPLES]
else:
y = np.pad(y, (0, MAX_SAMPLES - len(y)), mode='constant')
# 5. Trích xuất đồng thời các bộ đặc trưng (code này không đổi)
traditional_features = _extract_traditional_features(y, SAMPLE_RATE)
wav2vec_features = _extract_wav2vec_features(y, SAMPLE_RATE)
spectrogram = _create_spectrogram_image(y, SAMPLE_RATE)
return traditional_features, wav2vec_features, spectrogram
except Exception as e:
print(f"Lỗi nghiêm trọng khi xử lý file audio {file_path}: {e}")
traceback.print_exc()
return None, None, None
# --- ĐỊNH NGHĨA CÁC ROUTE CỦA ỨNG DỤNG ---
@app.route('/', methods=['GET'])
def home():
"""Render trang chủ của ứng dụng."""
return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
"""API endpoint để nhận file audio, xử lý và trả về kết quả dự đoán."""
if 'audio_file' not in request.files:
return jsonify({'error': 'Không có file audio nào trong yêu cầu.'}), 400
file = request.files['audio_file']
if file.filename == '':
return jsonify({'error': 'Tên file không hợp lệ.'}), 400
try:
filename = secure_filename(file.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(filepath)
# Xử lý file audio để trích xuất tất cả các đặc trưng cần thiết
trad_feats, w2v_feats, spec_img = process_audio_file(filepath)
if trad_feats is None:
return jsonify({'error': 'Không thể xử lý file audio.'}), 500
# Chuẩn bị dữ liệu đầu vào cho các mô hình
combined_feats = np.concatenate([trad_feats, w2v_feats]).reshape(1, -1)
scaled_feats = scaler.transform(combined_feats)
spec_img = spec_img / 255.0
spec_img = np.expand_dims(spec_img, axis=0)
# Lấy dự đoán từ tất cả các mô hình
pred_xgb = model_xgb.predict_proba(scaled_feats)[0][1]
pred_lgb = model_lgb.predict_proba(scaled_feats)[0][1]
pred_cnn = model_cnn.predict(spec_img, verbose=0)[0][0]
# Ensemble: Kết hợp kết quả bằng cách lấy trung bình xác suất
final_prediction_prob = (pred_xgb + pred_lgb + pred_cnn) / 3
final_prediction_label_index = 1 if final_prediction_prob > 0.5 else 0
result_label_text = label_encoder.inverse_transform([final_prediction_label_index])[0]
os.remove(filepath)
print(f"Phân tích hoàn tất. Kết quả: {result_label_text.upper()} (Xác suất: {final_prediction_prob:.2f})")
return jsonify({
'prediction': result_label_text.capitalize(),
'probability': f"{final_prediction_prob:.2f}"
})
except Exception as e:
print(f"Đã xảy ra lỗi trong quá trình dự đoán: {e}")
traceback.print_exc()
return jsonify({'error': 'Đã xảy ra lỗi không xác định trên máy chủ.'}), 500
# --- ĐIỂM BẮT ĐẦU CHẠY ỨNG DỤNG ---
if __name__ == '__main__':
port = int(os.environ.get("PORT", 7860))
app.run(host='0.0.0.0', port=port)