Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import random | |
import plotly.express as px | |
from huggingface_hub import snapshot_download | |
import os | |
import logging | |
from config import ( | |
SETUPS, | |
LOCAL_RESULTS_DIR, | |
CITATION_BUTTON_TEXT, | |
CITATION_BUTTON_LABEL, | |
) | |
from parsing import read_all_configs | |
# Configure the root logger once for the whole app: INFO level and above,
# timestamped records, emitted through a single StreamHandler (stderr by default).
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
# Logger for this module; its records propagate to the root handler above.
logger = logging.getLogger(__name__)
# Mirror the "g8a9/fair-asr-results" dataset repo into LOCAL_RESULTS_DIR
# before the UI below reads from it. Files matching *samples* or
# *transcripts* are not downloaded. The TOKEN environment variable, if set,
# is passed as the Hub access token. Any download error propagates and
# aborts start-up (the previous `except Exception as e: raise e` wrapper
# was a no-op and has been dropped).
logger.info("Saving results locally at: %s", LOCAL_RESULTS_DIR)
snapshot_download(
    repo_id="g8a9/fair-asr-results",
    local_dir=LOCAL_RESULTS_DIR,
    repo_type="dataset",
    tqdm_class=None,
    etag_timeout=30,
    ignore_patterns=["*samples*", "*transcripts*"],
    token=os.environ.get("TOKEN"),
)
def format_dataframe(df, times_100=False): | |
if times_100: | |
df = df.map(lambda x: (f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x)) | |
else: | |
df = df.map(lambda x: (f"{x:.4f}" if isinstance(x, (int, float)) else x)) | |
return df | |
# "F-M Setup" tab: compares models on the gap reported for the first
# configured setup (SETUPS[0]) as a ranking table, a model x language table,
# and a grouped bar chart.
with gr.Blocks() as fm_interface:
    fm = SETUPS[0]
    # Setup identifier passed to read_all_configs: "<majority>_<minority>".
    setup = fm["majority_group"] + "_" + fm["minority_group"]
    # Must provide at least "Model", "Language" and "Gap" columns (used below).
    # Gap is presumably a fraction, given the x100 percent scaling - confirm.
    results = read_all_configs(setup)

    # Overall ranking: 100 * sum of |Gap| over all of a model's rows.
    # Lower is better, so after the ascending sort row 0 is the best model.
    model_results = (
        results.pivot_table(
            index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
        )
        .reset_index()
        .sort_values("Gap")
    )
    best_model = model_results.iloc[0]["Model"]
    logger.info("Best model: %s", best_model)

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    gr.DataFrame(format_dataframe(model_results))

    # Model x Language table of gaps (pivot_table's default mean aggregation),
    # displayed as percentages.
    gr.Markdown("#### F-M gaps by language")
    lang_results = results.pivot_table(
        index="Model",
        values="Gap",
        columns="Language",
    ).reset_index()
    gr.DataFrame(format_dataframe(lang_results, times_100=True))

    # Scale to percent for plotting; lang_results was already computed on
    # the unscaled values above.
    results["Gap"] = results["Gap"] * 100
    fig = px.bar(
        results,
        x="Language",
        y="Gap",
        color="Model",
        title="Gaps by Language and Model",
        labels={
            # Bars show each row's per-language gap, not a sum of absolute gaps.
            "Gap": "Gap (%)",
            "Language": "Language",
            "Model": "Model",
        },
        barmode="group",
    )
    # Order the x axis by the best model's per-language gap, largest first.
    lang_order = (
        lang_results.set_index("Model")
        .loc[best_model]
        .sort_values(ascending=False)
        .index
    )
    logger.debug("Language order: %s", list(lang_order))
    fig.update_layout(xaxis={"categoryorder": "array", "categoryarray": lang_order})
    gr.Plot(fig)
# Top-level app: title, list of included datasets, one tab per setup
# interface, and a copyable citation box.
tabs = [fm_interface]
titles = ["F-M Setup"]
with gr.Blocks() as demo:
    gr.Markdown("# Twists, Humps, and Pebbles: ASR Leaderboard")
    gr.Markdown(
        """
Datasets currently included:
- **Mozilla Common Voice v17**
"""
    )
    gr.TabbedInterface(tabs, titles)
    gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        max_lines=6,
        show_copy_button=True,
    )

# Launch the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()