# Author: jerpint, commit 72c20ae ("add app")
import os
import gradio as gr
import json
import requests
import random
# Display names of the two classes, in confusion-matrix order:
# index 0 = real recordings, index 1 = cloned audio.
# The emoji (speaking head, robot) are written as escape sequences so the
# labels cannot be garbled again by a wrong-encoding round trip of this file.
labels = ["Real Audio \U0001F5E3\ufe0f", "Cloned Audio \U0001F916"]
# Display time passed as `duration=` to the gr.Info / gr.Warning toasts.
DURATION = 2
def get_accuracy(score_matrix) -> str:
    """Return the overall accuracy of a square confusion matrix as a percentage string.

    Accuracy is the sum of the diagonal (correct answers) divided by the sum of
    every cell. Works for any N x N matrix, so it matches
    ``confusion_matrix_to_markdown``, which renders matrices of any size.

    Args:
        score_matrix: Square matrix (sequence of rows) of counts.

    Returns:
        The accuracy with two decimals and a trailing ``%`` (e.g. ``"87.50%"``),
        or an empty string when the matrix holds no counts (including an empty
        matrix).
    """
    correct = sum(row[i] for i, row in enumerate(score_matrix))
    total = sum(sum(row) for row in score_matrix)
    if total == 0:
        # Nothing scored yet: show no accuracy rather than divide by zero.
        return ""
    accuracy = correct / total * 100
    return f"{accuracy:.2f}%"
def audio_link(path: str, model: str):
    """Build the download URL of one clip in the vox-cloned-data dataset.

    Args:
        path: File path of the clip inside the model's folder.
        model: Name of the folder the clip lives under (e.g. ``"commonvoice"``).

    Returns:
        ``<dataset resolve/main URL>/<model>/<path>?download=true``.
    """
    base_url = "https://huggingface.co/datasets/jerpint/vox-cloned-data/resolve/main"
    query = "?download=true"
    return f"{base_url}/{model}/{path}{query}"
def confusion_matrix_to_markdown(matrix, labels=None):
    """Render a confusion matrix as a Markdown table followed by an accuracy line.

    Args:
        matrix: Rows of counts; row ``i`` is rendered under ``labels[i]``.
        labels: Names used for both the header row and the row headers. When
            falsy (``None`` or empty), ``"Class <i>"`` is generated for each row.

    Returns:
        A Markdown string: a header row (with an empty corner cell), a ``---``
        separator row, one line per matrix row, then a blank line and
        ``Accuracy %: <get_accuracy(matrix)>``.
    """
    size = len(matrix)
    if not labels:
        labels = [f"Class {idx}" for idx in range(size)]
    accuracy = get_accuracy(matrix)

    lines = [
        "| " + " | ".join([""] + labels) + " |",
        "| " + " | ".join("---" for _ in range(size + 1)) + " |",
    ]
    for idx, row in enumerate(matrix):
        cells = " | ".join(str(value) for value in row)
        lines.append(f"| {labels[idx]} | {cells} |")

    table = "\n".join(lines) + "\n"
    return table + f"\nAccuracy %: {accuracy}\n"
def load_and_cache_data(timeout=30):
    """Return the dataset index, downloading it only if no local copy exists.

    On the first call, ``files.json`` is fetched from the vox-cloned-data
    dataset and saved as ``files.json`` in the current working directory.
    Later calls read that cached copy instead of hitting the network.

    Args:
        timeout: Seconds to wait for the download before giving up.

    Returns:
        The parsed JSON content of ``files.json``.

    Raises:
        Exception: If the download does not return HTTP 200.
        requests.exceptions.RequestException: On network errors, including a
            timeout.
        json.JSONDecodeError: If the downloaded or cached file is not valid JSON.
    """
    json_link = "https://huggingface.co/datasets/jerpint/vox-cloned-data/resolve/main/files.json?download=true"
    local_file = "files.json"
    if not os.path.exists(local_file):
        # Without a timeout, requests can wait forever on a stalled connection.
        response = requests.get(json_link, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"Failed to load data from {json_link}")
        text = response.text
        # Parse before caching so a bad payload is never persisted.
        data = json.loads(text)
        # Write to a temp file and atomically swap it in: an interrupted write
        # must not leave a truncated cache that breaks every later start.
        tmp_file = local_file + ".tmp"
        with open(tmp_file, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_file, local_file)
        return data
    # Explicit UTF-8: the platform default encoding may not match the JSON.
    with open(local_file, "r", encoding="utf-8") as f:
        return json.load(f)
def load_data(timeout=30):
    """Download and parse the dataset index (``files.json``) from the Hugging Face Hub.

    Always fetches over the network and never touches the local filesystem.

    Args:
        timeout: Seconds to wait for the Hub before giving up.

    Returns:
        The parsed JSON index. Elsewhere in this file it is used as a mapping
        from each clip path to the model names that provide it (presumably a
        list; confirm against the dataset).

    Raises:
        Exception: If the request does not return HTTP 200.
        requests.exceptions.RequestException: On network errors, including a
            timeout.
    """
    json_link = "https://huggingface.co/datasets/jerpint/vox-cloned-data/resolve/main/files.json?download=true"
    # Without a timeout, a stalled connection would hang app startup forever.
    response = requests.get(json_link, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to load data from {json_link}")
    print("Loaded data")
    return json.loads(response.text)
def select_random_model(path, models=None):
    """Select a random source model for a given path.

    Will select ``"commonvoice"`` (the real recording) 50% of the time, and a
    uniformly random other (cloned) model 50% of the time. The result is always
    one of the available models:

    - if there is no cloned model, ``"commonvoice"`` is returned;
    - if ``"commonvoice"`` is not listed, a cloned model is returned.

    Args:
        path: Clip path whose available models are looked up.
        models: Optional iterable of model names available for ``path``.
            Defaults to ``data[path]`` from the module-level dataset index.

    Returns:
        The chosen model name.

    Raises:
        IndexError: If no model at all is available.
    """
    models = list(data[path] if models is None else models)
    other_models = [m for m in models if m != "commonvoice"]
    has_real = "commonvoice" in models
    # Short-circuit keeps the original random-call sequence when both kinds
    # exist, and never offers a real clip that isn't in the list.
    if has_real and (not other_models or random.random() < 0.5):
        return "commonvoice"
    return random.choice(other_models)
def get_random_audio():
    """Pick a random clip for the quiz.

    Returns:
        A ``(path, model)`` tuple: a random entry of ``paths`` and the model
        chosen for it by ``select_random_model``.
    """
    chosen_path = random.choice(paths)
    return chosen_path, select_random_model(chosen_path)
def next_audio():
    """Draw a new random clip and build an audio player for it.

    Returns:
        ``(player, clip)``: a ``gr.Audio`` pointing at the clip's download URL,
        and the ``(path, model)`` tuple describing that clip.
    """
    clip = get_random_audio()
    url = audio_link(clip[0], clip[1])
    return gr.Audio(url), clip
# Fetch the index of available clips once, at startup.
data = load_data()
# Keep only paths provided by at least two sources (presumably so that both a
# real and a cloned version of the clip exist).
data = {
    clip_path: sources
    for clip_path, sources in data.items()
    if len(sources) >= 2
}
# Every path that survived the filter is a quiz candidate.
paths = list(data)
with gr.Blocks() as demo:
    # (path, model) of the clip currently being judged. Deliberately left empty
    # and filled by demo.load below. Gradio re-evaluates a callable State value
    # on every page load, while the player's URL was computed once at build
    # time. That mismatch scored a session's first guess against a clip the
    # user never heard. Loading both from one draw keeps them in sync.
    current_audio = gr.State()
    # Guess counts: rows = the user's guess, columns = the true source, both in
    # `labels` order ([real, cloned]).
    score_matrix = gr.State([[0, 0], [0, 0]])
    with gr.Column():
        with gr.Row():
            audio_cmp = gr.Audio()
        with gr.Column():
            with gr.Row():
                button1 = gr.Button(labels[0])
                button2 = gr.Button(labels[1])
            gr.Markdown("Rows: your guess, columns: actual source")
            score_md = gr.Markdown(
                confusion_matrix_to_markdown(score_matrix.value, labels)
            )

    def _score_guess(current, scores, guessed_real):
        """Score one guess, update the matrix, and serve the next clip.

        Args:
            current: ``(path, model)`` of the clip being judged, or ``None`` if
                the session state has not been initialised yet.
            scores: 2x2 count matrix (rows = guess, columns = truth), updated
                in place.
            guessed_real: True if the user clicked the "real" button.

        Returns:
            New values for ``[audio_cmp, current_audio, score_matrix, score_md]``.
        """
        new_player, new_clip = next_audio()
        if current is None:
            # Clicked before demo.load filled the state: nothing to score yet.
            return new_player, new_clip, scores, confusion_matrix_to_markdown(scores, labels)
        is_real = current[1] == "commonvoice"
        row = 0 if guessed_real else 1
        col = 0 if is_real else 1
        scores[row][col] += 1
        truth = "Real Audio" if is_real else "Cloned Audio"
        if guessed_real == is_real:
            gr.Info(f"Correct! {truth}", duration=DURATION)
        else:
            gr.Warning(f"Incorrect! {truth}", duration=DURATION)
        return new_player, new_clip, scores, confusion_matrix_to_markdown(scores, labels)

    @gr.on(
        triggers=[button1.click],
        inputs=[current_audio, score_matrix],
        outputs=[audio_cmp, current_audio, score_matrix, score_md],
    )
    def guess_real(current, scores):
        """Handle a click on the "real audio" button."""
        return _score_guess(current, scores, guessed_real=True)

    @gr.on(
        triggers=[button2.click],
        inputs=[current_audio, score_matrix],
        outputs=[audio_cmp, current_audio, score_matrix, score_md],
    )
    def guess_cloned(current, scores):
        """Handle a click on the "cloned audio" button."""
        return _score_guess(current, scores, guessed_real=False)

    # Draw each session's first clip; the player and the state come from the
    # same draw.
    demo.load(next_audio, outputs=[audio_cmp, current_audio])

demo.launch()