g8a9's picture
add minimal structure and parsing cv17 results
ad108b7
raw
history blame
3.25 kB
import gradio as gr
import pandas as pd
import random
import plotly.express as px
from huggingface_hub import snapshot_download
import os
import logging
from config import (
SETUPS,
LOCAL_RESULTS_DIR,
CITATION_BUTTON_TEXT,
CITATION_BUTTON_LABEL,
)
from parsing import read_all_configs
# Root logging configuration: records at INFO and above, timestamped and
# tagged with logger name and level, emitted through a single StreamHandler
# (stderr by default). To also keep a log file, append
# logging.FileHandler("app.log") to the handler list below.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger scoped to this module.
logger = logging.getLogger(__name__)
try:
print("Saving results locally at:", LOCAL_RESULTS_DIR)
snapshot_download(
repo_id="g8a9/fair-asr-results",
local_dir=LOCAL_RESULTS_DIR,
repo_type="dataset",
tqdm_class=None,
etag_timeout=30,
ignore_patterns=["*samples*", "*transcripts*"],
token=os.environ.get("TOKEN"),
)
except Exception as e:
raise e
def format_dataframe(df, times_100=False):
if times_100:
df = df.map(lambda x: (f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x))
else:
df = df.map(lambda x: (f"{x:.4f}" if isinstance(x, (int, float)) else x))
return df
with gr.Blocks() as fm_interface:
    # Only the first configured setup is rendered in this tab. Its results
    # are loaded by the "<majority_group>_<minority_group>" key.
    fm = SETUPS[0]
    setup = fm["majority_group"] + "_" + fm["minority_group"]
    results = read_all_configs(setup)

    # Rank models by the sum of absolute gaps over all their rows, scaled by
    # 100. Lower is better (ascending sort), so the first row is the best.
    model_results = (
        results.pivot_table(
            index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
        )
        .reset_index()
        .sort_values("Gap")
    )
    best_model = model_results.iloc[0]["Model"]
    logger.info("Best model: %s", best_model)

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    gr.DataFrame(format_dataframe(model_results))

    gr.Markdown("#### F-M gaps by language")
    # pivot_table's default aggfunc is the mean, so each cell is a model's
    # mean (signed) gap on one language. Gaps appear to be stored as
    # fractions, given the x100 scaling used throughout this block.
    lang_results = results.pivot_table(
        index="Model",
        values="Gap",
        columns="Language",
    ).reset_index()
    gr.DataFrame(format_dataframe(lang_results, times_100=True))

    # Plot a percentage-scaled copy rather than mutating `results` in place,
    # so the frame returned by read_all_configs keeps its original scale.
    plot_df = results.assign(Gap=results["Gap"] * 100)
    fig = px.bar(
        plot_df,
        x="Language",
        y="Gap",
        color="Model",
        title="Gaps by Language and Model",
        labels={
            # Bars show the per-row signed gap, not a sum of absolute gaps.
            "Gap": "Gap (%)",
            "Language": "Language",
            "Model": "Model",
        },
        barmode="group",
    )

    # Order the x-axis by the best model's mean gap per language, largest
    # first, so languages can be compared against the top-ranked model.
    lang_order = (
        lang_results.set_index("Model")
        .loc[best_model]
        .sort_values(ascending=False)
        .index
    )
    logger.info("Language order: %s", list(lang_order))
    fig.update_layout(xaxis={"categoryorder": "array", "categoryarray": lang_order})
    gr.Plot(fig)
# One tab per setup. titles[i] labels tabs[i] in the TabbedInterface below.
tabs = [fm_interface]
titles = ["F-M Setup"]

# Top-level page: title, dataset list, the per-setup tabs, and a copyable
# citation box.
with gr.Blocks() as demo:
    gr.Markdown("# Twists, Humps, and Pebbles: ASR Leaderboard")
    gr.Markdown(
        """
Datasets currently included:
- **Mozilla Common Voice v17**
"""
    )
    gr.TabbedInterface(tabs, titles)
    gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        max_lines=6,
        show_copy_button=True,
    )

if __name__ == "__main__":
    demo.launch()