import json
import os
import shutil
import sys
from collections import defaultdict
from statistics import mean

import pandas as pd
import requests

from text_normalizer import text_normalizer
from utils import compute_average_wer, download_dataset


def fetch_evaluation_data(url):
    """
    Fetches evaluation data from the given URL.

    :param url: The URL to fetch the evaluation data from.
    :returns: The evaluation data as a dictionary.
    :raises SystemExit: If the request fails.
    """
    response = requests.get(url)
    if response.status_code == 200:
        return json.loads(response.text)
    else:
        sys.exit(f"Failed to fetch WhisperKit evals: {response.text}")


def get_device_name(device):
    """
    Gets the device name from the device map if it exists.

    :param device: String representing the device name.
    :returns: The device name from the device map if it exists, otherwise the input device name.
    """
    with open("dashboard_data/device_map.json", "r") as f:
        device_map = json.load(f)
    return device_map.get(device, device).replace(" ", "_")
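# Illustrative note (hypothetical mapping, not taken from the real device map):
# if dashboard_data/device_map.json contained {"Mac14,13": "Mac Studio M2 Max"},
# get_device_name("Mac14,13") would return "Mac_Studio_M2_Max"; an identifier
# missing from the map is returned as-is, with spaces replaced by underscores.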


def process_quality_file(file_path, dataset_dfs, quality_results):
    """
    Processes a single quality file and updates the quality_results dictionary.

    :param file_path: Path to the quality JSON file.
    :param dataset_dfs: Dictionary of DataFrames containing dataset information.
    :param quality_results: Dictionary to store the processed quality results.

    This function reads a quality JSON file, extracts relevant information,
    and updates the quality_results dictionary with various metrics including WER
    and Quality of Inference (QoI) for different datasets.
    """
    with open(file_path, "r") as file:
        test_results = json.load(file)

    if len(test_results) == 0:
        return

    metadata = test_results["metadata"]
    test_results = test_results["results"]
    model = file_path.split("/")[-3].replace("_", "/")
    device = metadata["inference_context"]["device_spec"]["product_name"]
    device = get_device_name(device)
    timestamp = file_path.split("/")[-1].split(".")[0]
    key = model
    dataset_name = metadata["dataset_name"]

    for test_result in test_results:
        audio_file_name = test_result["file"]
        dataset_key = "Earnings-22" if "earnings22" in dataset_name else "LibriSpeech"
        dataset_df = dataset_dfs[dataset_key]

        wer_entry = {
            "prediction": text_normalizer(test_result["prediction"]),
            "reference": text_normalizer(test_result["reference"]),
        }
        quality_results[key]["timestamp"] = timestamp
        quality_results[key]["dataset_wer"][dataset_name].append(wer_entry)

        audio = audio_file_name.split(".")[0]
        dataset_row = dataset_df.loc[dataset_df["file"].str.contains(audio)].iloc[0]
        reference_wer = dataset_row["wer"]
        prediction_wer = test_result["wer"]
        quality_results[key]["qoi"].append(1 if prediction_wer <= reference_wer else 0)


def calculate_and_save_quality_results(quality_results, quality_output_path):
    """
    Calculates final quality metrics and saves them to a JSON file.

    :param quality_results: Dictionary containing raw quality data.
    :param quality_output_path: Path to save the processed quality results.

    This function processes the raw quality data, calculates average metrics,
    and writes the final results to a JSON file, with each entry representing
    a unique model's quality metrics across different datasets, including
    Word Error Rate (WER) and Quality of Inference (QoI).
    """
    with open(quality_output_path, "w") as quality_file:
        for key, data in quality_results.items():
            model = key
            dataset_wers = {
                dataset: compute_average_wer(wer)
                for dataset, wer in data["dataset_wer"].items()
            }
            average_wer = (
                sum(dataset_wers.values()) / len(dataset_wers)
                if len(dataset_wers) != 0
                else 0
            )

            quality_entry = {
                "model": model.replace("_", "/"),
                "timestamp": data["timestamp"],
                "average_wer": round(average_wer, 2),
                "dataset_wer": dataset_wers,
                "qoi": round(mean(data["qoi"]), 2),
            }

            json.dump(quality_entry, quality_file)
            quality_file.write("\n")
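# The output file is JSON Lines: one object per model. A purely illustrative
# (made-up) entry could look like:
# {"model": "openai/whisper-tiny", "timestamp": "2024-10-09_17:45:12",
#  "average_wer": 0.25, "dataset_wer": {"librispeech": 0.25}, "qoi": 0.86}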


def main():
    """
    Main function to orchestrate the quality data generation process.

    This function performs the following steps:
    1. Downloads quality data if requested.
    2. Fetches evaluation data for various datasets.
    3. Processes quality files for specific datasets.
    4. Calculates and saves quality results, including WER and QoI metrics.
    """
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree("english")
        except FileNotFoundError:
            print("Nothing to remove.")
        download_dataset("argmaxinc/whisperkit-evals", "english", "WhisperKit")

    # Reference evaluation runs (OpenAI Whisper large-v2) used as WER baselines;
    # several dashboard dataset keys point to the same reference file.
    datasets = {
        "Earnings-22": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "LibriSpeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-12hours": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
    }

    dataset_dfs = {}
    for dataset_name, url in datasets.items():
        evals = fetch_evaluation_data(url)
        dataset_dfs[dataset_name] = pd.json_normalize(evals["results"])

    source_quality_directory = "argmaxinc/english/WhisperKit/"

    quality_results = defaultdict(
        lambda: {
            "average_wer": [],
            "dataset_wer": defaultdict(list),
            "qoi": [],
            "timestamp": None,
        }
    )

    for subdir, _, files in os.walk(source_quality_directory):
        dataset = subdir.split("/")[-1]
        if dataset not in ["earnings22-12hours", "librispeech"]:
            continue
        for filename in files:
            if not filename.endswith(".json"):
                continue
            file_path = os.path.join(subdir, filename)
            process_quality_file(file_path, dataset_dfs, quality_results)

    calculate_and_save_quality_results(
        quality_results, "dashboard_data/quality_data.json"
    )


if __name__ == "__main__":
    main()
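# Usage sketch (the script filename below is assumed; adjust to the actual file name):
#   python generate_quality_data.py download   # re-download the evals dataset, then process it
#   python generate_quality_data.py            # process already-downloaded local files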