# real_or_fake/app.py
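"""Gradio Space for real-vs-synthetic speech classification.

A fine-tuned Wav2Vec2 sequence-classification model labels an uploaded audio
clip either as genuine speech ("Real") or as the output of one of several
neural vocoders (see the id2label mapping below).
"""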
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
# Load the fine-tuned model from the Hugging Face Model Hub.
# The checkpoint appears to live in a subfolder of the voiceGUARD repo, so the
# repo id and subfolder are passed separately (a three-part repo id is invalid).
model_name = "Mrkomiljon/voiceGUARD"
subfolder = "wav2vec2_finetuned_model"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, subfolder=subfolder)
processor = Wav2Vec2Processor.from_pretrained(model_name, subfolder=subfolder)
model.eval()  # inference mode: disables dropout for deterministic predictions
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define label mapping
id2label = {
    0: "diffwave",
    1: "melgan",
    2: "parallel_wave_gan",
    3: "Real",
    4: "wavegrad",
    5: "wavnet",
    6: "wavernn",
}
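# Apart from "Real", each label names the neural vocoder presumed to have
# synthesized the clip; the spelling "wavnet" follows the model's own label set.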
# Define the prediction function
def predict_audio(filepath):
    target_sample_rate = 16000  # sample rate the model expects
    max_length = target_sample_rate * 10  # fix inputs to a 10-second window

    if filepath is None:
        return "No audio provided", ""

    try:
        # Load the audio from the path that Gradio supplies
        waveform, sample_rate = torchaudio.load(filepath)

        # Resample if the sample rate doesn't match the model's expected rate
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=target_sample_rate
            )
            waveform = resampler(waveform)

        # Truncate or zero-pad the waveform to ensure consistent input length
        if waveform.size(1) > max_length:
            waveform = waveform[:, :max_length]  # Truncate
        elif waveform.size(1) < max_length:
            waveform = torch.nn.functional.pad(
                waveform, (0, max_length - waveform.size(1))
            )  # Pad

        # Keep only the first channel so the model sees mono audio
        if waveform.ndim > 1:
            waveform = waveform[0]

        # Convert the waveform into model inputs
        inputs = processor(
            waveform.numpy(),
            sampling_rate=target_sample_rate,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs["input_values"].to(device)

        # Perform inference
        with torch.no_grad():
            logits = model(input_values).logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            predicted_label = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0, predicted_label].item()

        # Map the label id to a class name; return one value per output component
        class_name = id2label.get(predicted_label, "Unknown Class")
        return class_name, f"{confidence * 100:.2f}%"
    except Exception as e:
        return "Error", f"Could not process the audio file: {e}"
# Create the Gradio interface
iface = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.Label(label="Predicted Class"),
        gr.Label(label="Confidence"),
    ],
    title="Audio Classification with Wav2Vec2",
    description="Upload an audio file to classify it as real speech or the output of a known vocoder.",
)
# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()
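# Run with `python app.py`; by default Gradio serves the UI at http://127.0.0.1:7860.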