import gradio as gr
import pandas as pd
from datasets import load_dataset
from jiwer import wer, cer
import os
from datetime import datetime
import re
# Load the Bambara ASR dataset
dataset = load_dataset("sudoping01/bambara-asr-benchmark", name="default")["train"]
references = {row["id"]: row["text"] for row in dataset}
# Initialize leaderboard file if it doesn't exist
leaderboard_file = "leaderboard.csv"
if not os.path.exists(leaderboard_file):
pd.DataFrame(columns=["submitter", "WER", "CER", "timestamp"]).to_csv(leaderboard_file, index=False)
def normalize_text(text):
    """
    Normalize text for WER/CER calculation:
    - Convert to lowercase
    - Remove punctuation
    - Replace multiple spaces with a single space
    - Strip leading/trailing spaces
    Missing values (NaN/None) normalize to an empty string.
    """
    if not isinstance(text, str):
        # NaN/None would otherwise stringify to "nan"/"None" and slip past the empty check
        if pd.isna(text):
            return ""
        text = str(text)
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
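# Illustrative behaviour of normalize_text (made-up input strings):
#   normalize_text("  Hello, World!  ")  ->  "hello world"
#   normalize_text(float("nan"))         ->  ""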
def calculate_metrics(predictions_df):
"""
Calculate WER and CER for predictions against the reference dataset.
"""
results = []
for _, row in predictions_df.iterrows():
id_val = row["id"]
if id_val not in references:
continue
reference = normalize_text(references[id_val])
hypothesis = normalize_text(row["text"])
if not reference or not hypothesis:
continue
try:
sample_wer = wer(reference, hypothesis)
sample_cer = cer(reference, hypothesis)
results.append({
"id": id_val,
"wer": sample_wer,
"cer": sample_cer
})
except Exception:
pass # Skip invalid samples silently
if not results:
raise ValueError("No valid samples available for metric calculation")
avg_wer = sum(item["wer"] for item in results) / len(results)
avg_cer = sum(item["cer"] for item in results) / len(results)
return avg_wer, avg_cer, results
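# Minimal usage sketch for calculate_metrics; the ids/texts below are
# hypothetical, and only rows whose ids appear in `references` are scored:
#   preds = pd.DataFrame({"id": ["utt_001"], "text": ["i ni ce"]})
#   avg_wer, avg_cer, per_sample = calculate_metrics(preds)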
def process_submission(submitter_name, csv_file):
"""
Process the uploaded CSV, calculate metrics, and update the leaderboard.
"""
try:
df = pd.read_csv(csv_file)
if len(df) == 0:
return "Submission failed: The uploaded CSV file is empty. Please upload a valid CSV file with predictions.", None
if set(df.columns) != {"id", "text"}:
return f"Submission failed: The CSV file must contain exactly two columns: 'id' and 'text'. Found: {', '.join(df.columns)}", None
if df["id"].duplicated().any():
dup_ids = df[df["id"].duplicated(keep=False)]["id"].unique()
return f"Submission failed: Duplicate 'id' values detected: {', '.join(map(str, dup_ids[:5]))}", None
missing_ids = set(references.keys()) - set(df["id"])
extra_ids = set(df["id"]) - set(references.keys())
if missing_ids:
return f"Submission failed: Missing {len(missing_ids)} required 'id' values. First few: {', '.join(map(str, list(missing_ids)[:5]))}", None
if extra_ids:
return f"Submission failed: Found {len(extra_ids)} unrecognized 'id' values. First few: {', '.join(map(str, list(extra_ids)[:5]))}", None
empty_ids = [row["id"] for _, row in df.iterrows() if not normalize_text(row["text"])]
if empty_ids:
return f"Submission failed: Empty transcriptions detected for {len(empty_ids)} 'id' values. First few: {', '.join(map(str, empty_ids[:5]))}", None
# Calculate metrics
avg_wer, avg_cer, detailed_results = calculate_metrics(df)
n_valid = len(detailed_results)
        # calculate_metrics raises ValueError when nothing was scored, so n_valid > 0 here
# Update leaderboard
leaderboard = pd.read_csv(leaderboard_file)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
new_entry = pd.DataFrame(
[[submitter_name, avg_wer, avg_cer, timestamp]],
columns=["submitter", "WER", "CER", "timestamp"]
)
leaderboard = pd.concat([leaderboard, new_entry]).sort_values("WER")
leaderboard.to_csv(leaderboard_file, index=False)
        # Format leaderboard for display via the shared helper
        display_leaderboard = load_and_format_leaderboard()
return f"Your submission has been successfully processed. Evaluated {n_valid} valid samples. WER: {avg_wer:.4f}, CER: {avg_cer:.4f}", display_leaderboard
except Exception as e:
return f"Submission failed: An error occurred while processing your file - {str(e)}", None
def load_and_format_leaderboard():
"""
Load the leaderboard and format WER/CER for display.
"""
if os.path.exists(leaderboard_file):
leaderboard = pd.read_csv(leaderboard_file)
leaderboard["WER"] = leaderboard["WER"].apply(lambda x: f"{x:.4f}")
leaderboard["CER"] = leaderboard["CER"].apply(lambda x: f"{x:.4f}")
return leaderboard
return pd.DataFrame(columns=["submitter", "WER", "CER", "timestamp"])
# Gradio interface
with gr.Blocks(title="Bambara ASR Benchmark Leaderboard") as demo:
gr.Markdown(
"""
## Bambara ASR Benchmark Leaderboard
**Welcome to the Bambara Automatic Speech Recognition (ASR) Benchmark Leaderboard**
Evaluate your ASR model's performance on the Bambara language dataset.
### Submission Instructions
1. Prepare a CSV file with two columns:
- **`id`**: Must match identifiers in the official dataset.
- **`text`**: Your model's transcription predictions.
2. Ensure the CSV file meets these requirements:
- Contains only `id` and `text` columns.
- No duplicate `id` values.
- All `id` values match dataset entries.
3. Upload your CSV file below.
### Dataset
Access the official dataset: [Bambara ASR Dataset](https://huggingface.co/datasets/MALIBA-AI/bambara_general_leaderboard_dataset)
### Evaluation Metrics
- **Word Error Rate (WER)**: Fraction of word-level errors (substitutions, insertions, deletions) relative to the reference; lower is better.
- **Character Error Rate (CER)**: The same edit-distance computation at the character level; lower is better.
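For example, a hypothesis that gets one word wrong in a three-word reference scores WER = 1/3 ≈ 0.3333.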
### Leaderboard
Submissions are ranked by WER and include:
- Submitter name
- WER (4 decimal places)
- CER (4 decimal places)
- Submission timestamp
"""
)
with gr.Row():
submitter = gr.Textbox(label="Submitter Name or Model Identifier", placeholder="e.g., MALIBA-AI/asr")
csv_upload = gr.File(label="Upload Prediction CSV File", file_types=[".csv"])
submit_btn = gr.Button("Evaluate Submission")
output_msg = gr.Textbox(label="Submission Status", interactive=False)
leaderboard_display = gr.DataFrame(
label="Current Leaderboard",
value=load_and_format_leaderboard(),
interactive=False
)
submit_btn.click(
fn=process_submission,
inputs=[submitter, csv_upload],
outputs=[output_msg, leaderboard_display]
)
if __name__ == "__main__":
    demo.launch(share=True)