import gradio as gr
import torch
import torchaudio
try:
    from speechbrain.inference import SpeakerRecognition  # SpeechBrain >= 1.0
except ImportError:
    from speechbrain.pretrained import SpeakerRecognition  # older SpeechBrain releases
import torch.nn as nn
from transformers import AutoModel
import os
from huggingface_hub import hf_hub_download
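# Run everything on GPU when available; the speaker encoder and classifier share this device.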
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
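# Pretrained ECAPA-TDNN speaker encoder (trained on VoxCeleb) for extracting voice embeddings.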
speaker_model = SpeakerRecognition.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir="tmp",
run_opts={"device": device}
)
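# Classifier head: treats the scalar speaker-similarity score as a one-token
# sequence, projects it into DistilRoBERTa's hidden space, and predicts whether
# the pair of recordings is an authentic match or a deepfake/impersonation.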
class PretrainedTransformerClassifier(nn.Module):
def __init__(self, num_classes=3):
super().__init__()
self.transformer = AutoModel.from_pretrained('distilbert/distilroberta-base')
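        # Freeze the whole backbone, then unfreeze only the last two encoder layers.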
for param in self.transformer.parameters():
param.requires_grad = False
for param in self.transformer.encoder.layer[-2:].parameters():
param.requires_grad = True
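        # Project the scalar similarity score into the transformer's 768-dim embedding space.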
self.embed_projection = nn.Sequential(
nn.Linear(1, 768),
nn.LayerNorm(768),
nn.Dropout(0.1)
)
self.classifier = nn.Sequential(
nn.Linear(768, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
def forward(self, x):
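        # x arrives as a (batch, 1) tensor of similarity scores; unsqueeze plus the
        # projection turns it into (batch, 1, 768) token embeddings for the transformer.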
x = self.embed_projection(x.unsqueeze(-1))
if len(x.shape) == 2:
x = x.unsqueeze(0)
        attention_mask = torch.ones(x.shape[:2], device=x.device)
transformer_output = self.transformer(
inputs_embeds=x,
attention_mask=attention_mask,
return_dict=True
)
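        # Pool by taking the final hidden state of the first (and only) token.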
pooled_output = transformer_output.last_hidden_state[:, 0, :]
return self.classifier(pooled_output)
# Load the fine-tuned classifier and its config from the Hugging Face Hub
def load_model():
model_path = hf_hub_download(repo_id="polygraf-ai/vexon-voice-authentication", filename="model.pth")
config_path = hf_hub_download(repo_id="polygraf-ai/vexon-voice-authentication", filename="config.pth")
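    # config.pth is assumed to hold at least {'num_classes': ...} so the head can be rebuilt.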
    config = torch.load(config_path, map_location=device)
classifier = PretrainedTransformerClassifier(num_classes=config['num_classes']).to(device)
classifier.load_state_dict(torch.load(model_path, map_location=device))
classifier.eval()
return classifier, config
classifier, model_config = load_model()
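# Encode an audio file into a fixed-size ECAPA speaker embedding (1-D numpy array).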
def extract_embedding(audio_path):
try:
        signal, fs = torchaudio.load(audio_path)
        # ECAPA-TDNN expects 16 kHz input; resample if the file uses a different rate.
        if fs != 16000:
            signal = torchaudio.functional.resample(signal, fs, 16000)
        signal = signal.to(device)
embedding = speaker_model.encode_batch(signal)
return embedding.cpu().detach().numpy().flatten()
except Exception as e:
print(f"Error processing {audio_path}: {e}")
return None
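# Cosine similarity between two speaker embeddings; higher values suggest the same speaker.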
def verify_speaker(audio_path1, audio_path2):
emb1 = extract_embedding(audio_path1)
emb2 = extract_embedding(audio_path2)
if emb1 is None or emb2 is None:
return None
tensor1 = torch.tensor(emb1).to(device)
tensor2 = torch.tensor(emb2).to(device)
similarity_score = torch.nn.functional.cosine_similarity(
tensor1, tensor2, dim=0
).cpu().item()
return similarity_score
def process_audio(audio1, audio2):
"""
Process two audio files and return authentication results using the pretrained transformer classifier
"""
if audio1 is None or audio2 is None:
return "Please upload both audio files."
temp_path1 = "temp_audio1.wav"
temp_path2 = "temp_audio2.wav"
try:
        # Re-encode the uploads as WAV files that torchaudio/SpeechBrain can read,
        # loading each file once instead of twice.
        waveform1, sample_rate1 = torchaudio.load(audio1)
        torchaudio.save(temp_path1, waveform1, sample_rate1)
        waveform2, sample_rate2 = torchaudio.load(audio2)
        torchaudio.save(temp_path2, waveform2, sample_rate2)
score = verify_speaker(temp_path1, temp_path2)
if score is None:
return "Error processing audio files. Please ensure they are valid audio recordings."
with torch.no_grad():
score_tensor = torch.FloatTensor([[score]]).to(device)
output = classifier(score_tensor)
prediction = torch.argmax(output, dim=1).item()
probabilities = torch.softmax(output, dim=1)[0]
confidence = probabilities[prediction].item()
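        # Class 0 is an authentic match; any other class is flagged as impersonation.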
label = "Original" if prediction == 0 else "Deepfake/Impersonation"
result = f"""
πŸ“Š Authentication Results:
πŸ”Ή Similarity Score: {score:.4f}
πŸ”Ή Classification: {label}
πŸ”Ή Confidence: {confidence:.4f}
{'⚠️ Potential Voice Impersonation Detected!' if prediction > 0 else 'βœ… Authentic Voice Match'}
"""
return result
except Exception as e:
return f"An error occurred: {str(e)}"
finally:
if os.path.exists(temp_path1):
os.remove(temp_path1)
if os.path.exists(temp_path2):
os.remove(temp_path2)
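# Light custom styling for the Gradio interface.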
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-radius: 8px;
background: linear-gradient(to right, #2125ff, #4146ff);
border: none;
cursor: pointer;
}
.gr-button:hover {
background: linear-gradient(to right, #1f23e6, #3b40e6);
}
.footer {
margin-top: 20px;
text-align: center;
border-top: 1px solid #ccc;
padding-top: 10px;
}
"""
demo = gr.Interface(
fn=process_audio,
inputs=[
gr.Audio(label="Reference Voice Recording", type="filepath"),
gr.Audio(label="Voice Recording to Verify", type="filepath")
],
outputs=gr.Textbox(label="Authentication Results"),
title="Vexon Voice Authentication",
description="""
Upload two voice recordings to verify if they are from the same person and detect potential voice impersonation attempts.""",
css=css
)
if __name__ == "__main__":
demo.launch()