File size: 7,319 Bytes
9d85ee2 8a02493 eed20cf 8a02493 9d85ee2 8a02493 e676bd8 8a02493 e676bd8 8a02493 e676bd8 8a02493 3e3e17d eed20cf 74e4942 8a02493 3155f54 65f6dc4 3155f54 65f6dc4 8a02493 3155f54 8a02493 eed20cf 8a02493 65f6dc4 e676bd8 3155f54 e676bd8 3155f54 e676bd8 3155f54 e67d080 e676bd8 8a02493 ea82efc 3155f54 e67d080 0673498 3155f54 b2ec3f0 3155f54 3e3e17d 3155f54 eed20cf 3155f54 3843f4e e676bd8 e67d080 e676bd8 3155f54 8a02493 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
import os
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
from nltk import ngrams
class Action(Enum):
    """Integer codes identifying the kind of each segment in a text diff.

    The values are compared against the first element of every
    ``(action, text)`` pair when a diff is rendered.
    """

    # Text present only in the second (compared) string.
    INSERTION = 1
    # Text present only in the first (reference) string.
    DELETION = -1
    # Text common to both strings.
    EQUAL = 0
def compare_string(text1: str, text2: str) -> list:
    """Diff two strings after Unicode normalisation.

    Both inputs are NFKC-normalised before diffing, and the raw diff is then
    passed through diff-match-patch's semantic clean-up.

    Args:
        text1: The string the diff starts from.
        text2: The string the diff leads to.

    Returns:
        The diff list built by ``diff_main`` after ``diff_cleanupSemantic``
        has been applied to it.
    """
    source, target = (
        unicodedata.normalize("NFKC", raw) for raw in (text1, text2)
    )
    differ = dmp_module.diff_match_patch()
    edits = differ.diff_main(source, target)
    # The clean-up return value is unused; the same list object is returned.
    differ.diff_cleanupSemantic(edits)
    return edits
def style_text(diff):
    """Render a diff as an HTML/Markdown string with highlighted edits.

    Args:
        diff: Iterable of ``(action, text)`` pairs, where ``action`` equals
            the value of one of the ``Action`` members.

    Returns:
        One string in which inserted text is wrapped in a ``Lightgreen``
        span, deleted text is struck through inside a ``#FFCCCB`` span, and
        unchanged text is emitted verbatim.

    Raises:
        NotImplementedError: If an action matches no ``Action`` member.
    """
    pieces = []
    for action, text in diff:
        if action == Action.INSERTION.value:
            pieces.append(f"<span style='background-color:Lightgreen'>{text}</span>")
        elif action == Action.DELETION.value:
            pieces.append(
                f"<span style='background-color:#FFCCCB'><s>{text}</s></span>"
            )
        elif action == Action.EQUAL.value:
            pieces.append(f"{text}")
        else:
            raise NotImplementedError(f"Unsupported diff action: {action!r}")

    # Join once instead of growing the string segment by segment.
    full_text = "".join(pieces)
    # Backslash-escape "](" and "~" in the rendered text, presumably so the
    # Markdown renderer does not read them as link / strikethrough syntax.
    # The backslashes are spelled "\\" because "\(" and "\~" are invalid
    # escape sequences that newer Python versions warn about.
    return full_text.replace("](", "]\\(").replace("~", "\\~")
# Validation split of the long-form TED-LIUM dataset; each row's "audio"
# entry is read when a single sample is visualised.
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Normalised target transcriptions and large-v2 predictions, as plain lists.
# NOTE(review): rows are assumed to be in the same order as `dataset` - confirm.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# Normalised large-32-2 predictions, kept under their own name so that
# `csv_v2` keeps referring to the large-v2 file.
csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()

# Audio arrays are scaled by the largest int16 value before being cast.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_statistics(model="large-v2", round_dp=2, ngram_degree=5):
    """Compute error statistics for one model over every transcription.

    Args:
        model: Which predictions to score: ``"large-v2"`` or ``"large-32-2"``.
        round_dp: Number of decimal places the two percentages are rounded to.
        ngram_degree: Number of words per n-gram in the repetition count.

    Returns:
        A tuple ``(wer_percentage, ier_percentage, repeated_ngrams)``:
        the word error rate in percent, the insertions as a percentage of
        all reference words, and the number of n-grams in the predictions
        that duplicate an earlier n-gram.

    Raises:
        ValueError: If ``model`` is not one of the two supported names.
    """
    if model == "large-v2":
        predictions = norm_pred_v2
    elif model == "large-32-2":
        predictions = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(norm_target, predictions, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)

    reference_word_count = sum(len(ref) for ref in wer_output.references)
    ier_percentage = round(
        100 * wer_output.insertions / reference_word_count, round_dp
    )

    # All predictions are joined into one word stream first, so n-grams may
    # span the boundary between consecutive samples.
    all_ngrams = list(ngrams(" ".join(predictions).split(), ngram_degree))
    # Total minus distinct n-grams; a set keeps this linear in corpus size.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    return wer_percentage, ier_percentage, repeated_ngrams
def get_overall_table():
    """Assemble the rows of the overall statistics table.

    Returns:
        A list with one row per model, ``"large-v2"`` first and then
        ``"large-32-2"``. Each row is the model name followed by the values
        returned by ``get_statistics`` for that model.
    """
    rows = []
    for model_name in ("large-v2", "large-32-2"):
        rows.append([model_name, *get_statistics(model=model_name)])
    return rows
def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    """Compute the audio, statistics and styled diff for a single sample.

    Args:
        idx: 1-based sample number (the UI slider starts at 1).
        model: Which predictions to score: ``"large-v2"`` or ``"large-32-2"``.
        round_dp: Number of decimal places the two percentages are rounded to.
        ngram_degree: Number of words per n-gram in the repetition count.

    Returns:
        A tuple ``((sampling_rate, array), wer_percentage, ier_percentage,
        repeated_ngrams, full_text)``. ``array`` is the sample's audio scaled
        to ``target_dtype``, and ``full_text`` is the styled diff of the
        prediction against the target transcription.

    Raises:
        ValueError: If ``model`` is not one of the two supported names.
    """
    # Reject an unknown model before doing any audio work.
    if model == "large-v2":
        predictions = norm_pred_v2
    elif model == "large-32-2":
        predictions = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    # Convert the 1-based sample number to a 0-based index.
    idx -= 1

    audio = dataset[idx]["audio"]
    # Scale by the maximum of `target_dtype` and cast to that same dtype, so
    # the scale factor and the cast cannot drift apart.
    # NOTE(review): assumes the decoded samples lie in [-1, 1] - confirm.
    array = (audio["array"] * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    reference = norm_target[idx]
    hypothesis = predictions[idx]

    wer_output = process_words(reference, hypothesis, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    all_ngrams = list(ngrams(hypothesis.split(), ngram_degree))
    # Total minus distinct n-grams; a set avoids a quadratic membership scan.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    full_text = style_text(compare_string(reference, hypothesis))

    return (
        (sampling_rate, array),
        wer_percentage,
        ier_percentage,
        repeated_ngrams,
        full_text,
    )
def get_side_by_side_visualisation(idx):
    """Collect the per-sample outputs of both models for the UI.

    Args:
        idx: Sample number, forwarded unchanged to ``get_visualisation``.

    Returns:
        A tuple ``(audio, table, whisper_text, distil_text)``. ``audio`` is
        taken from the large-v2 call. ``table`` holds one row per model: a
        display name followed by that model's statistics. The last two items
        are the styled diffs for large-v2 and large-32-2 respectively.
    """
    # Each result is (audio, *statistics, styled_text).
    audio, *whisper_stats, whisper_text = get_visualisation(idx, model="large-v2")
    _, *distil_stats, distil_text = get_visualisation(idx, model="large-32-2")

    table = [
        ["Whisper", *whisper_stats],
        ["Distil-Whisper", *distil_stats],
    ]
    return audio, table, whisper_text, distil_text
if __name__ == "__main__":
    # Column headers shared by the overall and the per-sample tables.
    stat_headers = (
        "Model",
        "Word Error Rate (WER)",
        "Insertion Error Rate (IER)",
        "Repeated 5-grams",
    )

    # Page title; the lines of this literal are kept exactly as in the source.
    title_html = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
Whisper Transcription Analysis
</h1>
</div>
</div>
"""

    intro_markdown = (
        "Analyse the transcriptions generated by the Whisper and Distil-Whisper models on the TEDLIUM dev set. "
        "Analysis is performed on the overall level, where statistics are computed over the entire dev set, and also a per-sample level. "
        "The transcriptions for both models are shown at the bottom of the demo. The text diff for each is computed "
        "relative to the target transcriptions, where insertions are displayed in <span style='background-color:Lightgreen'>green</span>, and "
        "deletions in <span style='background-color:#FFCCCB'><s>red</s></span>."
    )

    # NOTE(review): the nesting of the layout contexts below is reconstructed
    # from statement order only - confirm against the deployed app.
    with gr.Blocks() as demo:
        gr.HTML(title_html)
        gr.Markdown(intro_markdown)

        # Corpus-level statistics, computed once while the UI is built.
        gr.Markdown("**Overall statistics:**")
        overall_table = gr.Dataframe(
            value=get_overall_table(),
            headers=list(stat_headers),
            row_count=2,
        )

        # Per-sample controls: the slider is 1-based, one step per sample.
        gr.Markdown("**Per-sample statistics:**")
        slider = gr.Slider(
            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")

        with gr.Column():
            sample_table = gr.Dataframe(
                headers=list(stat_headers),
                row_count=2,
            )
            with gr.Row():
                gr.Markdown("**Whisper text diff**")
                gr.Markdown("**Distil-Whisper text diff**")
            with gr.Row():
                whisper_diff_out = gr.Markdown(label="Text difference")
                distil_diff_out = gr.Markdown(label="Text difference")

        # Fill the audio player, the per-sample table and both diffs on click.
        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, sample_table, whisper_diff_out, distil_diff_out],
        )

    demo.launch()
|