Update app.py
app.py CHANGED
@@ -12,18 +12,18 @@ transform = transforms.Compose([
     transforms.RemoveWhiteSpace(replace_by_space=True),
 ])
 
-
+# Load the Bambara ASR dataset
 dataset = load_dataset("sudoping01/bambara-asr-benchmark", name="default")["train"]
 references = {row["id"]: row["text"] for row in dataset}
 
-
+# Load or initialize the leaderboard
 leaderboard_file = "leaderboard.csv"
 if not os.path.exists(leaderboard_file):
     pd.DataFrame(columns=["submitter", "WER", "CER", "timestamp"]).to_csv(leaderboard_file, index=False)
 
 def process_submission(submitter_name, csv_file):
     try:
-
+        # Read and validate the uploaded CSV
         df = pd.read_csv(csv_file)
         if set(df.columns) != {"id", "text"}:
             return "Error: CSV must contain exactly 'id' and 'text' columns.", None
@@ -32,7 +32,7 @@ def process_submission(submitter_name, csv_file):
         if set(df["id"]) != set(references.keys()):
             return "Error: CSV 'id's must match the dataset 'id's.", None
 
-
+        # Calculate WER and CER for each prediction
         wers, cers = [], []
         for _, row in df.iterrows():
             ref = references[row["id"]]
@@ -40,11 +40,11 @@ def process_submission(submitter_name, csv_file):
             wers.append(wer(ref, pred, truth_transform=transform, hypothesis_transform=transform))
             cers.append(cer(ref, pred, truth_transform=transform, hypothesis_transform=transform))
 
-
+        # Compute average WER and CER
         avg_wer = sum(wers) / len(wers)
         avg_cer = sum(cers) / len(cers)
 
-
+        # Update the leaderboard
        leaderboard = pd.read_csv(leaderboard_file)
         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         new_entry = pd.DataFrame(
@@ -58,7 +58,7 @@ def process_submission(submitter_name, csv_file):
     except Exception as e:
         return f"Error processing submission: {str(e)}", None
 
-
+# Create the Gradio interface
 with gr.Blocks(title="Bambara ASR Leaderboard") as demo:
     gr.Markdown(
         """
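For reference, below is a minimal, self-contained sketch of the submission format and the per-utterance scoring loop that the new comments describe. The utterance ids, transcriptions, and the my_submission.csv filename are hypothetical, and the sketch uses jiwer's default text normalization rather than the app's custom transform; a real submission must use the ids from sudoping01/bambara-asr-benchmark.

import pandas as pd
from jiwer import wer, cer

# Hypothetical reference transcriptions keyed by utterance id.
references = {"utt_001": "i ni ce", "utt_002": "i ka kene wa"}

# A valid submission CSV has exactly two columns: "id" and "text".
submission = pd.DataFrame(
    {"id": ["utt_001", "utt_002"], "text": ["i ni ce", "i ka kene"]}
)
submission.to_csv("my_submission.csv", index=False)

# Score each prediction against its reference, then average,
# mirroring the loop in process_submission.
df = pd.read_csv("my_submission.csv")
wers, cers = [], []
for _, row in df.iterrows():
    ref, pred = references[row["id"]], row["text"]
    wers.append(wer(ref, pred))  # per-utterance word error rate
    cers.append(cer(ref, pred))  # per-utterance character error rate

print(f"avg WER: {sum(wers) / len(wers):.3f}  avg CER: {sum(cers) / len(cers):.3f}")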