|
import json |
|
import os |
|
import shutil |
|
import sys |
|
from collections import defaultdict |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from sklearn.metrics import confusion_matrix |
|
|
|
from utils import compute_average_wer, download_dataset |
|
|
|
|
|
def main():
    """
    Orchestrate the multilingual data generation process.

    Steps:
    1. Optionally (re-)downloads the multilingual evaluation dataset when the
       script is invoked with a ``download`` argument.
    2. Walks the downloaded evaluation JSON files and groups per-item results
       under a ``"<model>/<forced|not_forced>"`` key.
    3. Delegates to ``calculate_and_save_results`` to compute Word Error Rate
       (WER) metrics and language-detection confusion matrices and write them
       to disk.
    """
    source_repo = "argmaxinc/whisperkit-evals-multilingual"
    source_subfolder = "WhisperKit"
    source_directory = f"{source_repo}/{source_subfolder}"
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        # Start from a clean checkout; a missing directory just means there is
        # nothing to clean up. Any other OS error should surface loudly.
        try:
            shutil.rmtree(source_repo)
        except FileNotFoundError:
            print("Nothing to remove.")
        download_dataset(source_repo, source_repo, source_subfolder)

    results = defaultdict(
        lambda: {
            "average_wer": [],
            "language_wer": defaultdict(list),
            "language_detection": [],
        }
    )

    confusion_matrices = {}

    for subdir, _, files in os.walk(source_directory):
        for filename in files:
            # Only per-run result files; "summary" files are aggregates that
            # this script rebuilds itself.
            if not filename.endswith(".json") or "summary" in filename:
                continue

            file_path = os.path.join(subdir, filename)
            with open(file_path, "r") as f:
                data = json.load(f)

            subdir_components = subdir.split(os.path.sep)
            is_forced = "forced" in subdir_components
            # The model name sits one directory level higher when a "forced"
            # path component is present — presumably reflecting the dataset's
            # directory layout; TODO confirm against the repo structure.
            model = subdir_components[-4] if is_forced else subdir_components[-3]

            key = f"{model}/{'forced' if is_forced else 'not_forced'}"

            for item in data["results"]:
                # Items without a reference language cannot be scored.
                if "reference_language" not in item:
                    continue
                reference_language = item["reference_language"]
                detected_language = item["predicted_language"]

                result = {
                    "reference": item["reference"],
                    "prediction": item["prediction"],
                }

                results[key]["average_wer"].append(result)
                results[key]["language_wer"][reference_language].append(result)
                results[key]["language_detection"].append(
                    (reference_language, detected_language)
                )

    calculate_and_save_results(results, confusion_matrices)
|
|
|
def calculate_and_save_results(results, confusion_matrices):
    """
    Calculate final multilingual metrics and save them to CSV and JSON files.

    :param results: Dictionary mapping ``"<model>/<forced|not_forced>"`` keys
        to raw evaluation data: per-item WER entries (overall and per
        reference language) plus ``(reference, detected)`` language pairs.
    :param confusion_matrices: Dictionary populated in place with
        row-normalized language-detection confusion matrices, keyed by model
        and then by forced mode.

    Writes, under ``dashboard_data/`` (created if missing):
    1. ``multilingual_results.csv`` — WER per model and per language.
    2. ``multilingual_confusion_matrices.json`` — confusion matrices.
    """
    wer_data = []
    for key, data in results.items():
        model, forced = key.rsplit("/", 1)
        model = model.replace("_", "/")
        row = {
            "Model": model,
            "Forced Tokens": forced == "forced",
            "Average WER": compute_average_wer(data["average_wer"]),
        }
        for lang, wers in data["language_wer"].items():
            row[f"WER_{lang}"] = compute_average_wer(wers)
        wer_data.append(row)

        # No detection pairs means there is nothing to tabulate for this key
        # (and zip(*[]) would raise a ValueError on unpacking).
        if not data["language_detection"]:
            continue

        true_languages, detected_languages = zip(*data["language_detection"])
        unique_languages = sorted(set(true_languages))
        cm = confusion_matrix(
            true_languages, detected_languages, labels=unique_languages
        )

        # Normalize each row to proportions; rows that sum to zero (languages
        # never seen as a reference) are left at 0 to avoid division by zero.
        row_sums = cm.sum(axis=1)
        cm_normalized = np.zeros_like(cm, dtype=float)
        non_zero_rows = row_sums != 0
        cm_normalized[non_zero_rows] = (
            cm[non_zero_rows] / row_sums[non_zero_rows, np.newaxis]
        )

        confusion_matrices.setdefault(model, {})[forced] = {
            "matrix": cm_normalized.tolist(),
            "labels": unique_languages,
        }

    # Ensure the output directory exists before writing either artifact.
    os.makedirs("dashboard_data", exist_ok=True)

    df = pd.DataFrame(wer_data)
    df.to_csv("dashboard_data/multilingual_results.csv", index=False)

    with open("dashboard_data/multilingual_confusion_matrices.json", "w") as f:
        json.dump(confusion_matrices, f, indent=2)
|
|
|
# Script entry point: run the multilingual results generation pipeline.
if __name__ == "__main__":
    main()
|