import os
from functools import partial
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
class Action(Enum):
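    """Diff operation codes, matching the values returned by diff-match-patch."""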
INSERTION = 1
DELETION = -1
EQUAL = 0
def compare_string(text1: str, text2: str) -> list:
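    """Diff two strings at character level after NFKC Unicode normalisation."""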
text1_normalized = unicodedata.normalize("NFKC", text1)
text2_normalized = unicodedata.normalize("NFKC", text2)
dmp = dmp_module.diff_match_patch()
diff = dmp.diff_main(text1_normalized, text2_normalized)
dmp.diff_cleanupSemantic(diff)
return diff
def style_text(diff):
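    """Render a diff as HTML: green highlights for insertions, red strikethrough for deletions."""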
fullText = ""
for action, text in diff:
if action == Action.INSERTION.value:
fullText += f"<span style='background-color:Lightgreen'>{text}</span>"
elif action == Action.DELETION.value:
fullText += f"<span style='background-color:#FFCCCB'><s>{text}</s></span>"
elif action == Action.EQUAL.value:
fullText += f"{text}"
        else:
            raise ValueError(f"Unhandled diff action: {action}")
    # Escape characters that Markdown would otherwise interpret as formatting
    fullText = fullText.replace("](", "]\\(").replace("~", "\\~")
return fullText
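
# Long-form TEDLIUM validation set: full talks with reference transcripts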
dataset = load_dataset(
"distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()
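
# Audio arrays come as floats in [-1, 1]; scale to int16 PCM for Gradio playback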
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_visualisation(idx, model="v2"):
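    """Return the audio, WER/IER metrics, relative length, and styled diff for one sample.

    `idx` is 1-based, coming straight from the Gradio slider.
    """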
idx -= 1
audio = dataset[idx]["audio"]
    array = (audio["array"] * max_range).astype(target_dtype)
sampling_rate = audio["sampling_rate"]
text1 = norm_target[idx]
text2 = norm_pred_v2[idx] if model == "v2" else norm_pred_32_2[idx]
wer_output = process_words(text1, text2, wer_default, wer_default)
wer_percentage = round(100 * wer_output.wer, 2)
    # IER = insertions relative to the number of reference words
    ier_percentage = round(100 * wer_output.insertions / len(wer_output.references[0]), 2)
    # Length of the prediction relative to the reference, in words
    rel_length = round(len(text2.split()) / len(text1.split()), 2)
diff = compare_string(text1, text2)
full_text = style_text(diff)
return (sampling_rate, array), wer_percentage, ier_percentage, rel_length, full_text
def get_side_by_side_visualisation(idx):
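    """Run both models on the same sample and tabulate their metrics side by side."""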
large_v2 = get_visualisation(idx, model="v2")
large_32_2 = get_visualisation(idx, model="32-2")
table = [large_v2[1:-1], large_32_2[1:-1]]
table[0] = ["large-v2", *table[0]]
table[1] = ["large-32-2", *table[1]]
return large_v2[0], table, large_v2[-1], large_32_2[-1]
if __name__ == "__main__":
with gr.Blocks() as demo:
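        # One tab per model, plus a side-by-side comparison tab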
with gr.Tab("large-v2"):
gr.Markdown(
"Analyse the transcriptions generated by the Whisper large-v2 model on the TEDLIUM dev set."
)
slider = gr.Slider(
minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
)
btn = gr.Button("Analyse")
audio_out = gr.Audio(label="Audio input")
with gr.Row():
wer = gr.Number(label="Word Error Rate (WER)")
                ier = gr.Number(label="Insertion Error Rate (IER)")
                relative_length = gr.Number(
                    label="Relative length (prediction length / reference length)"
                )
text_out = gr.Markdown(label="Text difference")
btn.click(
fn=partial(get_visualisation, model="v2"),
inputs=slider,
outputs=[audio_out, wer, ier, relative_length, text_out],
)
with gr.Tab("large-32-2"):
gr.Markdown(
"Analyse the transcriptions generated by the Whisper large-32-2 model on the TEDLIUM dev set."
)
slider = gr.Slider(
minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
)
btn = gr.Button("Analyse")
audio_out = gr.Audio(label="Audio input")
with gr.Row():
wer = gr.Number(label="Word Error Rate (WER)")
                ier = gr.Number(label="Insertion Error Rate (IER)")
                relative_length = gr.Number(
                    label="Relative length (prediction length / reference length)"
                )
text_out = gr.Markdown(label="Text difference")
btn.click(
fn=partial(get_visualisation, model="32-2"),
inputs=slider,
outputs=[audio_out, wer, ier, relative_length, text_out],
)
with gr.Tab("side-by-side"):
gr.Markdown(
"Analyse the transcriptions generated by the Whisper large-32-2 model on the TEDLIUM dev set."
)
slider = gr.Slider(
minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
)
btn = gr.Button("Analyse")
audio_out = gr.Audio(label="Audio input")
with gr.Column():
                table = gr.Dataframe(
                    headers=[
                        "Model",
                        "Word Error Rate (WER)",
                        "Insertion Error Rate (IER)",
                        "Relative length (prediction / reference)",
                    ],
                    height=1000,
                )
with gr.Row():
gr.Markdown("large-v2 text diff")
gr.Markdown("large-32-2 text diff")
with gr.Row():
text_out_v2 = gr.Markdown(label="Text difference")
text_out_32_2 = gr.Markdown(label="Text difference")
btn.click(
fn=get_side_by_side_visualisation,
inputs=slider,
outputs=[audio_out, table, text_out_v2, text_out_32_2],
)
demo.launch()