File size: 4,201 Bytes
9ba8fab
 
 
60c60cf
9ba8fab
 
 
5d4699a
9ba8fab
 
 
5d4699a
9ba8fab
 
 
 
60c60cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba8fab
 
5d4699a
9ba8fab
3efa4cc
29c8f24
 
3efa4cc
9ba8fab
 
3efa4cc
9ba8fab
 
3efa4cc
5d4699a
9ba8fab
 
3efa4cc
60c60cf
 
3efa4cc
60c60cf
 
 
3efa4cc
60c60cf
 
 
 
3efa4cc
5d4699a
60c60cf
 
 
9ba8fab
 
 
5d4699a
9ba8fab
 
 
 
 
 
 
 
 
 
3efa4cc
9ba8fab
 
 
5d4699a
9ba8fab
 
 
 
3efa4cc
 
9ba8fab
 
 
 
 
3efa4cc
9ba8fab
3bdb09a
9ba8fab
3efa4cc
9ba8fab
 
 
 
 
 
 
 
 
 
 
 
 
 
29c8f24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import pandas as pd
from datasets import load_dataset
from jiwer import wer, cer
import os
from datetime import datetime

# Benchmark split whose transcriptions serve as the scoring references.
dataset = load_dataset("sudoping01/bambara-asr-benchmark", name="default")["train"]
# Utterance id -> reference transcription, used to look up each submitted row.
references = {example["id"]: example["text"] for example in dataset}

# Leaderboard storage. On first run, create a header-only CSV so the
# reads of this file elsewhere in the app always find it.
leaderboard_file = "leaderboard.csv"
if not os.path.exists(leaderboard_file):
    pd.DataFrame(
        columns=["submitter", "WER", "CER", "timestamp"]
    ).to_csv(leaderboard_file, index=False)

# Punctuation removed before scoring. The table is built once so each call
# makes a single str.translate pass instead of one str.replace per character.
_PUNCTUATION_TABLE = str.maketrans("", "", ",.!?;:\"'")


def preprocess_text(text):
    """
    Normalise a transcription before WER/CER scoring.

    Removes the ASCII punctuation characters , . ! ? ; : " and ', lowercases,
    and collapses every run of whitespace to a single space (trimming both ends).

    Args:
        text: The transcription. Non-string values are converted with ``str``.
            Missing values (``None``, NaN, ``pd.NA``) become ``""`` rather than
            the literal words "none"/"nan".

    Returns:
        The normalised string, which may be empty.
    """
    # A missing cell must not be scored as the word "nan" or "none".
    if pd.api.types.is_scalar(text) and pd.isna(text):
        return ""

    text = str(text).translate(_PUNCTUATION_TABLE)
    text = text.lower()
    # split() with no argument splits on any whitespace and drops empty parts.
    return " ".join(text.split())

def process_submission(submitter_name, csv_file):
    """
    Score an uploaded prediction CSV against the reference transcriptions and
    record the result on the leaderboard.

    Args:
        submitter_name: Name recorded on the leaderboard (surrounding
            whitespace is stripped; it must not be blank).
        csv_file: The uploaded CSV, as passed by the Gradio ``File`` component.
            It must have exactly ``id`` and ``text`` columns, with one row per
            reference id.

    Returns:
        A ``(message, leaderboard)`` tuple. On success ``leaderboard`` is the
        updated leaderboard DataFrame sorted by WER. On any error it is
        ``None`` and ``message`` describes the problem.
    """
    if csv_file is None:
        return "Error: Please upload a CSV file.", None

    submitter_name = str(submitter_name or "").strip()
    if not submitter_name:
        return "Error: Please provide a submitter or model name.", None

    try:
        # Read every cell as a literal string. Ids then compare reliably with
        # the reference ids (no int/str mismatch, no lost leading zeros), and
        # transcripts such as "NA", "null" or "" are not turned into NaN.
        df = pd.read_csv(csv_file, dtype=str, keep_default_na=False)

        if set(df.columns) != {"id", "text"}:
            return "Error: CSV must contain exactly 'id' and 'text' columns.", None

        if df["id"].duplicated().any():
            return "Error: Duplicate 'id's found in the CSV.", None

        # Reference ids might not be strings in the dataset. Key them by str so
        # they compare equal to the string ids read from the CSV.
        refs_by_id = {str(ref_id): ref_text for ref_id, ref_text in references.items()}
        if set(df["id"]) != set(refs_by_id):
            return "Error: CSV 'id's must match the dataset 'id's.", None

        wers, cers = [], []
        for sample_id, raw_pred in zip(df["id"], df["text"]):
            ref = preprocess_text(refs_by_id[sample_id])
            pred = preprocess_text(raw_pred)

            # WER/CER are normalised by the reference length, so they are
            # undefined for an empty reference. Leave such utterances out.
            if not ref:
                continue

            if not pred:
                # An empty hypothesis makes every reference token a deletion,
                # so both rates are exactly 1.0. Scoring it, rather than
                # skipping it, stops empty outputs from improving the average.
                wers.append(1.0)
                cers.append(1.0)
                continue

            # jiwer still applies its own default transforms (e.g. stripping
            # and splitting into words/characters) on top of preprocess_text.
            wers.append(wer(ref, pred))
            cers.append(cer(ref, pred))

        if not wers:
            return "Error: No valid text pairs for evaluation after preprocessing.", None

        # NOTE: this is the unweighted mean of per-utterance rates, not a
        # corpus-level WER/CER (total errors / total reference tokens).
        avg_wer = sum(wers) / len(wers)
        avg_cer = sum(cers) / len(cers)

        leaderboard = pd.read_csv(leaderboard_file)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        new_entry = pd.DataFrame(
            [[submitter_name, avg_wer, avg_cer, timestamp]],
            columns=["submitter", "WER", "CER", "timestamp"],
        )
        # Use a stable sort so earlier submissions stay ahead on tied WER.
        leaderboard = pd.concat([leaderboard, new_entry], ignore_index=True).sort_values(
            "WER", kind="mergesort"
        )
        leaderboard.to_csv(leaderboard_file, index=False)

        return "Submission processed successfully!", leaderboard

    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing.
        return f"Error processing submission: {str(e)}", None

def _submit_and_refresh(submitter_name, csv_file):
    """
    Click handler wrapping ``process_submission`` for the UI.

    ``process_submission`` returns ``None`` for the leaderboard when it
    fails, and that would blank the table. In that case, re-read the stored
    leaderboard so the current standings stay visible next to the error.
    """
    message, board = process_submission(submitter_name, csv_file)
    if board is None:
        board = pd.read_csv(leaderboard_file)
    return message, board


# Create the Gradio interface
with gr.Blocks(title="Bambara ASR Leaderboard") as demo:
    # NOTE(review): the dataset linked below
    # (MALIBA-AI/bambara_general_leaderboard_dataset) is not the repo the
    # references are loaded from (sudoping01/bambara-asr-benchmark). Confirm
    # they share the same ids so submitters are pointed at the right data.
    gr.Markdown(
        """
        # Bambara ASR Leaderboard
        Upload a CSV file with 'id' and 'text' columns to evaluate your ASR predictions.
        The 'id's must match those in the dataset.
        [View the dataset here](https://huggingface.co/datasets/MALIBA-AI/bambara_general_leaderboard_dataset).
        - **WER**: Word Error Rate (lower is better).
        - **CER**: Character Error Rate (lower is better).
        """
    )

    with gr.Row():
        submitter = gr.Textbox(label="Submitter Name or Model Name", placeholder="e.g., MALIBA-AI/asr")
        csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])

    submit_btn = gr.Button("Submit")
    output_msg = gr.Textbox(label="Status", interactive=False)
    leaderboard_display = gr.DataFrame(
        label="Leaderboard",
        # A callable value is evaluated on every page load, so visitors see
        # the current standings rather than a snapshot taken at start-up.
        value=lambda: pd.read_csv(leaderboard_file),
        interactive=False,
    )

    submit_btn.click(
        fn=_submit_and_refresh,
        inputs=[submitter, csv_upload],
        outputs=[output_msg, leaderboard_display],
    )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)