|
from flask import Flask, render_template, request, jsonify |
|
from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForXVector |
|
import torchaudio |
|
import torch |
|
import io |
|
import librosa |
|
from scipy.spatial.distance import cosine |
|
import numpy as np |
|
import os |
|
|
|
|
|
|
|
# Flask application object; static assets are served under the /static URL prefix.
app = Flask(__name__, static_url_path="/static")

# Reference recordings that uploads are compared against:
# mp3_file_path is used by /compare_audio, mp3_file_path2 by /compare_audio2.
mp3_file_path = "arnold.mp3"
mp3_file_path2 = "arnold2.wav"
|
|
|
# Challenge secrets, read once at import time and echoed back to the client on
# a successful match. The encoding is explicit so decoding does not depend on
# the host locale's default encoding.
with open("flag1.txt", encoding="utf-8") as f:
    flag1 = f.read()

with open("flag2.txt", encoding="utf-8") as f:
    flag2 = f.read()
|
|
|
|
|
# Speaker-verification checkpoint used by /compare_audio. A local "model"
# path, when it exists (presumably a pre-downloaded copy), takes precedence
# over the Hugging Face Hub identifier.
themodel = "model" if os.path.exists("model") else "microsoft/unispeech-sat-large-sv"

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(themodel)
model = UniSpeechSatForXVector.from_pretrained(themodel)
|
|
|
|
|
def preprocess_audio(audio_data):
    """Load audio and return it as a mono, 16 kHz NumPy waveform.

    Args:
        audio_data: a file path or file-like object accepted by
            ``torchaudio.load`` (callers here pass both).

    Returns:
        The samples as a NumPy array with singleton dimensions squeezed out.
    """
    target_rate = 16000
    samples, rate = torchaudio.load(audio_data)

    channel_count = samples.shape[0]
    if channel_count > 1:
        # Average all channels into a single mono track.
        samples = samples.mean(dim=0, keepdim=True)

    if rate != target_rate:
        resampler = torchaudio.transforms.Resample(rate, target_rate)
        samples = resampler(samples)

    return samples.squeeze().numpy()
|
|
|
@app.route('/')
def index():
    """Render the landing page from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)
|
|
|
@app.route('/chal2')
def chal2():
    """Render the second challenge's page from the ``chal2.html`` template."""
    page = 'chal2.html'
    return render_template(page)
|
|
|
|
|
|
|
|
|
|
|
# Normalised speaker embeddings of reference recordings, keyed by file path.
# The reference audio is fixed, so its embedding is computed on first use and
# reused instead of running a second full model forward pass on every request.
_reference_embeddings = {}


def _speaker_embedding(audio_source):
    """Return the L2-normalised x-vector embedding of *audio_source*.

    Args:
        audio_source: anything ``preprocess_audio`` accepts (a file path or an
            uploaded file object).

    Returns:
        The model's ``embeddings`` output, normalised along the last axis and
        moved to the CPU.
    """
    waveform = preprocess_audio(audio_source)
    inputs = feature_extractor(waveform, return_tensors="pt")
    # Inference only: disabling autograd avoids keeping activations alive for
    # a backward pass that never happens.
    with torch.no_grad():
        embeddings = model(**inputs).embeddings
    return torch.nn.functional.normalize(embeddings, dim=-1).cpu()


def _reference_embedding(path):
    """Return the (cached) speaker embedding of the reference recording at *path*."""
    embedding = _reference_embeddings.get(path)
    if embedding is None:
        embedding = _speaker_embedding(path)
        _reference_embeddings[path] = embedding
    return embedding


@app.route('/compare_audio', methods=['POST'])
def compare_audio():
    """Challenge 1: compare an uploaded voice against the reference speaker.

    The ``audio_data`` form file is embedded with the speaker-verification
    model and scored against the embedding of ``mp3_file_path`` by cosine
    similarity, rounded to 3 decimals.

    Returns:
        JSON ``{'result': ...}`` holding an HTML message (with ``flag1`` and a
        link to /chal2 on a match), or ``{'error': ...}`` if any step raises.
    """
    threshold = 0.89
    try:
        recorded_embedding = _speaker_embedding(request.files['audio_data'])
        reference_embedding = _reference_embedding(mp3_file_path)

        cosine_sim = torch.nn.CosineSimilarity(dim=-1)
        similarity = cosine_sim(recorded_embedding, reference_embedding).item()
        similarity = round(similarity, 3)

        # Fail closed: NaN (e.g. produced by NaN samples in a crafted upload)
        # compares False with everything, so `similarity < threshold` would
        # wrongly treat it as a match.
        if not similarity >= threshold:
            result = f"Authorization Failed! {similarity!s} < {threshold:.3f}<br>Do your best Terminator impression"
        else:
            result = "Good job! Match: " + str(similarity) + "<br>" + flag1 + "<br><a href='/chal2'>Click here to open the next challenge</a>"

        return jsonify({'result': result})

    except Exception as e:
        print("Caught: "+str(e))
        return jsonify({'error': 'An error occurred during audio comparison. Im fragile please dont abuse.' })
|
|
|
def extract_mfcc(audio_bytes):
    """Return 13 MFCCs of an encoded audio clip.

    The clip is decoded, mixed to mono, resampled to 16 kHz and
    silence-trimmed by ``preprocess_audio2``; the coefficients are then
    computed at that 16 kHz rate.

    Args:
        audio_bytes: raw bytes of an audio file.

    Returns:
        The ``librosa.feature.mfcc`` matrix (coefficients x frames for a mono
        waveform).
    """
    samples = preprocess_audio2(audio_bytes)
    return librosa.feature.mfcc(y=samples, sr=16000, n_mfcc=13)
|
|
|
def preprocess_audio2(audio_bytes):
    """Decode audio bytes into a mono, 16 kHz, silence-trimmed NumPy waveform.

    Args:
        audio_bytes: raw bytes of an audio file in a format torchaudio can decode.

    Returns:
        The trimmed samples as a NumPy array with singleton dimensions
        squeezed out.
    """
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    # Average multi-channel recordings down to a single mono channel.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

    # librosa.effects.trim is documented to take an np.ndarray; hand it one
    # instead of relying on implicit torch.Tensor coercion, which older
    # librosa releases reject during input validation.
    samples = waveform.numpy()
    # Strip leading/trailing segments more than 20 dB below librosa's
    # reference level (treated as silence).
    samples, _ = librosa.effects.trim(samples, top_db=20)

    return samples.squeeze()
|
|
|
@app.route('/compare_audio2', methods=['POST'])
def compare_audio2():
    """Challenge 2: compare an upload's average MFCC profile with the reference.

    Both the ``audio_data`` form file and ``mp3_file_path2`` go through
    ``extract_mfcc``; each MFCC matrix is averaged over axis 1 and the two
    vectors are scored as 1 - cosine distance, rounded to 3 decimals.

    Returns:
        JSON ``{'result': ...}`` holding an HTML message (with ``flag2`` on a
        match), or ``{'error': ...}`` if any step raises.
    """
    threshold = 0.940
    try:
        recorded_audio = request.files['audio_data'].read()
        # Context manager closes the reference file; a bare open().read()
        # leaked one file descriptor per request.
        with open(mp3_file_path2, 'rb') as reference_file:
            mp3_audio = reference_file.read()

        mfcc1 = extract_mfcc(recorded_audio)
        mfcc2 = extract_mfcc(mp3_audio)
        similarity = 1 - cosine(np.mean(mfcc1, axis=1), np.mean(mfcc2, axis=1))
        similarity = round(similarity, 3)

        # Fail closed: NaN (e.g. from NaN samples in a crafted upload, or a
        # zero-norm vector in the cosine distance) compares False with
        # everything, so `similarity < threshold` would wrongly let it pass.
        if not similarity >= threshold:
            result = f"Authorization Failed! {similarity!s} < {threshold:.3f}<br>Say: 'With great power comes great responsibility' as Arnold Schwarzenegger"
        else:
            result = "Good job! Match: " + str(similarity) + "<br>" + flag2

        return jsonify({'result': result})

    except Exception as e:
        print("Caught: "+str(e))
        return jsonify({'error': 'An error occurred during audio comparison. Im fragile please dont abuse.'})
|
|
|
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which can execute
    # arbitrary Python code; it must not be forced on for a server listening
    # on all interfaces. Flask's app.run() still honours FLASK_DEBUG=1 from
    # the environment when debug mode is wanted for local development.
    app.run(host="0.0.0.0", port=8080)