g8a9's picture
enhance data processing and visualization: add support for common languages and improve handling of models with NaN values
86e679c
raw
history blame
5.78 kB
import gradio as gr
import pandas as pd
import random
import plotly.express as px
from huggingface_hub import snapshot_download
import os
import logging
from config import (
SETUPS,
LOCAL_RESULTS_DIR,
CITATION_BUTTON_TEXT,
CITATION_BUTTON_LABEL,
)
from parsing import read_all_configs, get_common_langs
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        # logging.FileHandler("app.log"),
        logging.StreamHandler()
    ],
)
logger = logging.getLogger(__name__)

# Hugging Face dataset repository holding the precomputed results.
_RESULTS_REPO_ID = "g8a9/fair-asr-results"

# Fetch the results at import time, before any leaderboard component is built.
# NOTE(review): the parsing helpers presumably read from LOCAL_RESULTS_DIR —
# confirm. Per-sample and transcript files are skipped via ignore_patterns;
# the TOKEN environment variable, if set, is passed as the Hub access token.
logger.info("Saving results locally at: %s", LOCAL_RESULTS_DIR)
try:
    snapshot_download(
        repo_id=_RESULTS_REPO_ID,
        local_dir=LOCAL_RESULTS_DIR,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        ignore_patterns=["*samples*", "*transcripts*"],
        token=os.environ.get("TOKEN"),
    )
except Exception:
    # Record the failure with context, then propagate it unchanged: a bare
    # ``raise`` keeps the original traceback (``raise e`` would add a frame).
    logger.exception("Failed to download results from %s", _RESULTS_REPO_ID)
    raise
def format_dataframe(df, times_100=False):
if times_100:
df = df.map(lambda x: (f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x))
else:
df = df.map(lambda x: (f"{x:.4f}" if isinstance(x, (int, float)) else x))
return df
def _build_models_with_nan_md(models_with_nan):
model_markups = [f"*{m}*" for m in models_with_nan]
return f"""
We are currently hiding the results of {', '.join(model_markups)} because they don't support all languages.
"""
def build_components(show_common_langs):
    """Rebuild every leaderboard component for the chosen language filter.

    Args:
        show_common_langs: Current value of the "Main configuration" checkbox
            group; a truthy value restricts results to the common languages.

    Returns:
        A 4-tuple of Gradio components, in this order:
        - the aggregated table;
        - the per-language table (shown as percentages);
        - the bar plot;
        - the hidden-models notice, visible only when at least one model was
          dropped.
    """
    agg_table, per_lang_table, figure, hidden_models = _populate_components(
        show_common_langs
    )
    notice = _build_models_with_nan_md(hidden_models)
    any_hidden = len(hidden_models) > 0

    aggregated_view = gr.DataFrame(format_dataframe(agg_table))
    per_lang_view = gr.DataFrame(format_dataframe(per_lang_table, times_100=True))
    plot_view = gr.Plot(figure)
    notice_view = gr.Markdown(notice, visible=any_hidden)
    return aggregated_view, per_lang_view, plot_view, notice_view
def _populate_components(show_common_langs):
    """Load the results and compute the data behind the leaderboard components.

    Results for the first configured setup (``SETUPS[0]``) are read. They are
    optionally restricted to the common languages. Models with any missing
    value are then dropped, so that the remaining models are compared on the
    same rows.

    Args:
        show_common_langs: When truthy, keep only the languages returned by
            ``get_common_langs()``.

    Returns:
        A tuple ``(aggregated_df, lang_df, barplot_fig, models_with_nan)``:
        - ``aggregated_df``: one row per model, with ``Gap`` equal to 100x
          the sum of absolute gaps, sorted ascending (best model first).
        - ``lang_df``: one row per model and one column per language,
          holding the unscaled gap (``pivot_table``'s default mean
          aggregation).
        - ``barplot_fig``: a grouped bar chart of the per-language gaps
          (x100) for the top 3 models.
        - ``models_with_nan``: the list of hidden model names.
    """
    fm = SETUPS[0]
    setup = fm["majority_group"] + "_" + fm["minority_group"]
    results = read_all_configs(setup)

    if show_common_langs:
        common_langs = get_common_langs()
        results = results[results["Language"].isin(common_langs)]

    # Models with any NaN (reported to users as not supporting all languages)
    # are hidden so the remaining ranking compares like with like.
    models_with_nan = results[results.isna().any(axis=1)]["Model"].unique().tolist()
    logger.info("Models with NaN values: %s", models_with_nan)
    results = results[~results["Model"].isin(models_with_nan)]

    # Lower total absolute gap is better, so ascending order puts the best first.
    aggregated_df = (
        results.pivot_table(
            index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
        )
        .reset_index()
        .sort_values("Gap")
    )
    top_3_models = aggregated_df["Model"].head(3).tolist()

    lang_df = results.pivot_table(
        index="Model",
        values="Gap",
        columns="Language",
    ).reset_index()

    # Scale a separate copy for plotting. Assigning into ``results`` (a
    # boolean-filtered slice) would trigger pandas' SettingWithCopyWarning.
    plot_df = results.loc[results["Model"].isin(top_3_models)].assign(
        Gap=lambda d: d["Gap"] * 100
    )
    barplot_fig = px.bar(
        plot_df,
        x="Language",
        y="Gap",
        color="Model",
        title="Gaps by Language and Model (top 3, sorted by the best model)",
        labels={
            # Bars show each model's per-language gap, not a sum of absolutes.
            "Gap": "Gap (%)",
            "Language": "Language",
            "Model": "Model",
        },
        barmode="group",
    )

    # Order the x axis by the best model's gaps, largest first. This is skipped
    # when no model survived the NaN filter, where ``iloc[0]`` would raise.
    if not aggregated_df.empty:
        best_model = aggregated_df.iloc[0]["Model"]
        lang_order = (
            lang_df.set_index("Model")
            .loc[best_model]
            .sort_values(ascending=False)
            .index.tolist()
        )
        logger.info("Lang order: %s", lang_order)
        barplot_fig.update_layout(
            xaxis={"categoryorder": "array", "categoryarray": lang_order}
        )

    return aggregated_df, lang_df, barplot_fig, models_with_nan
# F-M tab, initially built over all languages; the "Show only common languages"
# checkbox in ``demo`` starts with no value. The component variables stay
# module-level because ``demo`` lists them as outputs of the checkbox handler.
with gr.Blocks() as fm_interface:
    aggregated_df, lang_df, barplot_fig, model_with_nan = _populate_components(
        show_common_langs=False
    )
    # Show the hidden-models notice only when some model was actually hidden,
    # matching the visibility rule used by build_components on updates.
    model_with_nans_md = gr.Markdown(
        _build_models_with_nan_md(model_with_nan),
        visible=len(model_with_nan) > 0,
    )
    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))
    gr.Markdown("#### F-M gaps by language")
    lang_df_comp = gr.DataFrame(format_dataframe(lang_df, times_100=True))
    barplot_fig_comp = gr.Plot(barplot_fig)
# ---------------------------------------------------------------------------
# Tabs shown in the main interface, paired positionally with their titles.
# ---------------------------------------------------------------------------
tabs, titles = [fm_interface], ["F-M Setup"]

# HTML header: a full-width banner image (styled by the inline CSS class).
banner = """
<style>
.full-width-image {
width: 100%;
height: auto;
margin: 0;
padding: 0;
}
</style>
<div>
<img src="https://huggingface.co/spaces/g8a9/fair-asr-leaderboard/raw/main/twists_banner.png" alt="Twists Banner" class="full-width-image">
</div>
"""
###################
# MAIN INTERFACE
###################
with gr.Blocks() as demo:
    gr.HTML(banner)
    gr.Markdown("# Fair ASR Leaderboard")
    with gr.Row() as config_row:
        # Any selection here restricts the leaderboard to common languages.
        show_common_langs = gr.CheckboxGroup(
            choices=["Show only common languages"],
            label="Main configuration",
        )
        # Display-only: one pre-selected dataset that users cannot toggle.
        include_datasets = gr.CheckboxGroup(
            choices=["Mozilla CV 17"],
            label="Include datasets",
            value=["Mozilla CV 17"],
            interactive=False,
        )
    # Rebuild the F-M tab's tables, plot and notice on user input.
    show_common_langs.input(
        build_components,
        inputs=[show_common_langs],
        outputs=[
            aggregated_df_comp,
            lang_df_comp,
            barplot_fig_comp,
            model_with_nans_md,
        ],
    )
    gr.TabbedInterface(tabs, titles)
    gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        max_lines=6,
        show_copy_button=True,
    )

if __name__ == "__main__":
    demo.launch()