jerpint commited on
Commit
72c20ae
β€’
1 Parent(s): 554d877
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ import requests
5
+ import random
6
+
7
+ labels = ["Real Audio πŸ—£οΈ", "Cloned Audio πŸ€–"]
8
+ DURATION = 2
9
+
10
+
11
+ def get_accuracy(score_matrix) -> str:
12
+
13
+ correct = score_matrix[0][0] + score_matrix[1][1]
14
+ total = sum(score_matrix[0]) + sum(score_matrix[1])
15
+ if total == 0:
16
+ return ""
17
+
18
+ accuracy = correct / total * 100
19
+ return f"{accuracy:.2f}%"
20
+
21
+
22
+ def audio_link(path: str, model: str):
23
+ """Get the link to the audio file for a given path and model."""
24
+ return f"https://huggingface.co/datasets/jerpint/vox-cloned-data/resolve/main/{model}/{path}?download=true"
25
+
26
+
27
+ def confusion_matrix_to_markdown(matrix, labels=None):
28
+ num_labels = len(matrix)
29
+ labels = labels or [f"Class {i}" for i in range(num_labels)]
30
+ accuracy = get_accuracy(matrix)
31
+
32
+ # Header row
33
+ markdown = f"| {' | '.join([''] + labels)} |\n"
34
+ markdown += f"| {' | '.join(['---'] * (num_labels + 1))} |\n"
35
+
36
+ # Data rows
37
+ for i, row in enumerate(matrix):
38
+ markdown += f"| {labels[i]} | " + " | ".join(map(str, row)) + " |\n"
39
+
40
+ markdown += f"\nAccuracy %: {accuracy}\n"
41
+
42
+ return markdown
43
+
44
+
45
+ def load_and_cache_data():
46
+ json_link = "https://huggingface.co/datasets/jerpint/vox-cloned-data/resolve/main/files.json?download=true"
47
+ local_file = "files.json"
48
+
49
+ if not os.path.exists(local_file):
50
+ json_file = requests.get(json_link)
51
+ if json_file.status_code != 200:
52
+ raise Exception(f"Failed to load data from {json_link}")
53
+
54
+ # Cache the file
55
+ with open(local_file, "w") as f:
56
+ f.write(json_file.text)
57
+
58
+ with open(local_file, "r") as f:
59
+ return json.load(f)
60
+
61
+
62
+ def load_data():
63
+ json_link = "https://huggingface.co/datasets/jerpint/vox-cloned-data/resolve/main/files.json?download=true"
64
+ json_file = requests.get(json_link)
65
+ if json_file.status_code != 200:
66
+ raise Exception(f"Failed to load data from {json_link}")
67
+ print("Loaded data")
68
+ return json.loads(json_file.text)
69
+
70
+
71
+ def select_random_model(path):
72
+ """Select a random model from the list of models for a given path.
73
+ Will select commonvoice 50% of the time, and a random other model 50% of the time.
74
+ """
75
+ if random.random() < 0.5:
76
+ return "commonvoice"
77
+ else:
78
+ other_models = [m for m in data[path] if m != "commonvoice"]
79
+ return random.choice(other_models)
80
+
81
+
82
+ def get_random_audio():
83
+ path = random.choice(paths)
84
+ model = select_random_model(path)
85
+ return path, model
86
+
87
+
88
+ def next_audio():
89
+ new_audio = get_random_audio()
90
+ audio_cmp = gr.Audio(audio_link(new_audio[0], new_audio[1]))
91
+ return audio_cmp, new_audio
92
+
93
+
94
+ data = load_data()
95
+
96
+ # Keep only samples with minimum 2 sources
97
+ data = {path: data[path] for path in data if len(data[path]) >= 2}
98
+
99
+ # List all available paths
100
+ paths = list(data.keys())
101
+
102
+
103
+ with gr.Blocks() as demo:
104
+ current_audio = gr.State(get_random_audio)
105
+ score_matrix = gr.State([[0, 0], [0, 0]])
106
+
107
+ with gr.Column():
108
+ with gr.Row():
109
+ audio_cmp = gr.Audio(
110
+ audio_link(current_audio.value[0], current_audio.value[1])
111
+ )
112
+ with gr.Column():
113
+ with gr.Row():
114
+ button1 = gr.Button("Real Audio πŸ—£οΈ")
115
+ button2 = gr.Button("Cloned Audio πŸ€–")
116
+
117
+ score_md = gr.Markdown(confusion_matrix_to_markdown(score_matrix.value, labels))
118
+
119
+ @gr.on(
120
+ triggers=[button1.click],
121
+ inputs=[current_audio, score_matrix],
122
+ outputs=[audio_cmp, current_audio, score_matrix, score_md],
123
+ )
124
+ def check_result(x, score_matrix):
125
+ is_correct = x[1] == "commonvoice"
126
+ audio_cmp, current_audio = next_audio()
127
+ if is_correct:
128
+ gr.Info("Correct! Real Audio", duration=DURATION)
129
+ score_matrix[0][0] += 1
130
+ else:
131
+ gr.Warning("Incorrect! Cloned Audio", duration=DURATION)
132
+ score_matrix[0][1] += 1
133
+
134
+ score_md = confusion_matrix_to_markdown(score_matrix, labels)
135
+ return audio_cmp, current_audio, score_matrix, score_md
136
+
137
+ @gr.on(
138
+ triggers=[button2.click],
139
+ inputs=[current_audio, score_matrix],
140
+ outputs=[audio_cmp, current_audio, score_matrix, score_md],
141
+ )
142
+ def check_result(x, score_matrix):
143
+ is_correct = x[1] != "commonvoice"
144
+ audio_cmp, current_audio = next_audio()
145
+ if is_correct:
146
+ gr.Info("Correct! Cloned Audio", duration=DURATION)
147
+ score_matrix[1][1] += 1
148
+ else:
149
+ gr.Warning("Incorrect! Real Audio", duration=DURATION)
150
+ score_matrix[1][0] += 1
151
+ score_md = confusion_matrix_to_markdown(score_matrix, labels)
152
+ return audio_cmp, current_audio, score_matrix, score_md
153
+
154
+
155
+ demo.launch()