Upload 5 files
- README.md +76 -13
- app.py +149 -0
- requirements.txt +3 -0
- submissions.db +0 -0
- testsets/spaccc_gender_dataset_test.csv +0 -0
README.md
CHANGED
@@ -1,13 +1,76 @@
# Text Classification Benchmark Leaderboard

This project provides a **leaderboard** for evaluating **Text Classification** models. Users can upload their model predictions as a CSV file, compare performance metrics against ground-truth datasets, and track submissions over time.

## Features
- **Dataset Selection:** Users can choose a dataset from predefined test sets.
- **Submission Upload:** Supports CSV files with `file_name` and `label` columns.
- **Automated Evaluation:** Calculates **Accuracy, Precision, Recall, and F1-score**.
- **Leaderboard Tracking:** Stores and displays past experiments.
- **Gradio Interface:** Simple and interactive web interface.

## Requirements
Ensure you have the following installed before running the project:

```bash
pip install gradio pandas sqlalchemy scikit-learn
```

## Setup & Usage

1. **Clone the repository:**
   ```bash
   git clone https://github.com/nlp4bia-bsc/text-classification-leaderboard.git
   cd text-classification-leaderboard
   ```

2. **Run the application:**
   ```bash
   python app.py
   ```

3. **Access the interface:**
   The application runs locally. Open your browser and go to:
   ```
   http://127.0.0.1:7860/
   ```
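Test sets are plain CSV files discovered at startup, so adding a new benchmark is just a file drop into `testsets/`. A minimal sketch (the dataset name `my_new_dataset` and its rows are hypothetical):

```python
import pandas as pd

# Hypothetical example: register a new test set. The gold file uses the same
# file_name/label schema as submissions and is listed under its base name
# ("my_new_dataset") the next time app.py starts.
gold = pd.DataFrame({
    "file_name": ["doc1.txt", "doc2.txt", "doc3.txt"],
    "label": ["spam", "ham", "spam"],
})
gold.to_csv("testsets/my_new_dataset.csv", index=False)
```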
## Submission Format
Your submission file must be a **CSV** containing the following columns:

| file_name | label |
|-----------|-------|
| doc1.txt  | spam  |
| doc2.txt  | ham   |
| doc3.txt  | spam  |
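A compliant submission can be generated from model outputs with pandas. A minimal sketch; the `predict` stub is hypothetical and stands in for your own classifier:

```python
import pandas as pd

def predict(file_name: str) -> str:
    """Hypothetical placeholder; replace with your model's inference."""
    return "spam"

docs = ["doc1.txt", "doc2.txt", "doc3.txt"]
submission = pd.DataFrame({"file_name": docs, "label": [predict(d) for d in docs]})
submission.to_csv("my_submission.csv", index=False)  # ready to upload
```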
### Evaluation Metrics
The system calculates the following metrics (a local sanity-check sketch follows the list):
- **Accuracy**
- **Precision (weighted)**
- **Recall (weighted)**
- **F1-score (weighted)**
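You can reproduce the scoring locally before uploading. A minimal sketch with scikit-learn, assuming your predictions are already aligned with the gold labels; `zero_division=0` is an addition here to silence warnings on unseen labels, while app.py uses the defaults:

```python
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Toy aligned label lists; substitute your own gold and predicted labels.
y_true = ["spam", "ham", "spam", "ham"]
y_pred = ["spam", "ham", "ham", "ham"]

print("Accuracy :", accuracy_score(y_true, y_pred))
print("Precision:", precision_score(y_true, y_pred, average="weighted", zero_division=0))
print("Recall   :", recall_score(y_true, y_pred, average="weighted", zero_division=0))
print("F1       :", f1_score(y_true, y_pred, average="weighted", zero_division=0))
```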
## Directory Structure
```
text-classification-leaderboard/
│── testsets/        # Folder containing test datasets
│── submissions.db   # SQLite database for storing results
│── app.py           # Main application script
│── README.md        # Project documentation
```

## Future Improvements
- Add support for multi-label classification.
- Expand dataset compatibility with more formats.

## License
This project is licensed under the **MIT License**. Feel free to contribute and enhance it!

## Contributing
Pull requests are welcome! If you have suggestions or find issues, please open an issue on the repository.

---
**Author:** Wesam Alnabki
**GitHub:** [wesamalnabki](https://github.com/wesamalnabki)
app.py
ADDED
@@ -0,0 +1,149 @@
```python
import os
from datetime import datetime

import gradio as gr
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime
from sqlalchemy.orm import declarative_base, sessionmaker

testsets_root_path = "./testsets/"

# Load every CSV in the testsets folder, keyed by file name without extension
def load_testsets(testsets_root_path: str) -> dict:
    datasets_dict = {}
    for ds in os.listdir(testsets_root_path):
        if ds.endswith(".csv"):  # Ensure only CSV files are processed
            csv_path = os.path.join(testsets_root_path, ds)
            df = pd.read_csv(csv_path)
            datasets_dict[ds.replace(".csv", "")] = df
    return datasets_dict

# Database setup
Base = declarative_base()

class Submission(Base):
    __tablename__ = 'submissions'
    id = Column(Integer, primary_key=True)
    dataset_name = Column(String)
    submission_name = Column(String)
    model_link = Column(String)
    person_name = Column(String)
    accuracy = Column(Float)
    precision = Column(Float)
    recall = Column(Float)
    f1 = Column(Float)
    submission_date = Column(DateTime, default=datetime.utcnow)

engine = create_engine('sqlite:///submissions.db')
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()

# Fetch previous submissions for the selected dataset, newest first
def get_existing_submissions(dataset_name):
    existing_submissions = session.query(Submission).filter_by(dataset_name=dataset_name).order_by(
        Submission.submission_date.desc()).all()

    submissions_list = [{
        "Submission Name": sub.submission_name,
        "Model Link": sub.model_link,
        "Person Name": sub.person_name,
        "Accuracy": sub.accuracy,
        "Precision": sub.precision,
        "Recall": sub.recall,
        "F1": sub.f1,
        "Submission Date": sub.submission_date.strftime("%Y-%m-%d %H:%M:%S")
    } for sub in existing_submissions]

    return pd.DataFrame(submissions_list) if submissions_list else pd.DataFrame(columns=[
        "Submission Name", "Model Link", "Person Name", "Accuracy", "Precision", "Recall", "F1", "Submission Date"
    ])

# Evaluation function for text classification
def calculate_metrics(gs, pred):
    # Align predictions with the gold standard on file_name so that row order
    # in the uploaded CSV does not matter (assumes the gold CSV also carries a
    # file_name column, matching the submission format).
    merged = gs.merge(pred, on="file_name", suffixes=("_true", "_pred"))
    y_true = merged["label_true"]
    y_pred = merged["label_pred"]
    try:
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')
        return accuracy, precision, recall, f1
    except Exception:  # never let a malformed submission crash the app
        return None, None, None, None

def benchmark_interface(dataset_name, submission_file, submission_name, model_link, person_name):
    if not all([dataset_name, submission_file, submission_name, model_link, person_name]):
        return {"error": "All fields are required."}, pd.DataFrame()

    dataset_dict = load_testsets(testsets_root_path)
    df_gs = dataset_dict.get(dataset_name)
    if df_gs is None:
        return {"error": "Dataset not found."}, pd.DataFrame()

    # Parse the uploaded submission CSV
    submission_df = pd.read_csv(submission_file.name)

    # Ensure the required columns are present
    if not all(col in submission_df.columns for col in ['file_name', 'label']):
        return {"error": "Submission file must contain 'file_name' and 'label' columns."}, pd.DataFrame()

    # Calculate metrics
    accuracy, precision, recall, f1 = calculate_metrics(gs=df_gs, pred=submission_df)
    metrics = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1}
    if f1 is not None:
        # Save the scored submission to the database
        new_submission = Submission(
            dataset_name=dataset_name,
            submission_name=submission_name,
            model_link=model_link,
            person_name=person_name,
            accuracy=accuracy,
            precision=precision,
            recall=recall,
            f1=f1
        )
        session.add(new_submission)
        session.commit()

    # Fetch the updated leaderboard
    submissions_df = get_existing_submissions(dataset_name)
    return metrics, submissions_df


def create_gradio_app():
    dataset_dict = load_testsets(testsets_root_path)
    dataset_names = list(dataset_dict.keys())

    with gr.Blocks() as demo:
        gr.Markdown("## Benchmarking Leaderboard for Text Classification")
        dataset_radio = gr.Radio(choices=dataset_names, label="Select Dataset")
        submission_file = gr.File(label="Upload Submission CSV")
        submission_name = gr.Textbox(label="Submission Name")
        model_link = gr.Textbox(label="Model Link on HuggingFace")
        person_name = gr.Textbox(label="Person Name")
        submit_button = gr.Button("Submit")
        metrics_output = gr.JSON(label="Evaluation Metrics")
        existing_submissions_output = gr.Dataframe(label="Existing Submissions")

        # When a dataset is selected, show its previous submissions
        dataset_radio.change(
            fn=get_existing_submissions,
            inputs=[dataset_radio],
            outputs=[existing_submissions_output]
        )

        submit_button.click(
            fn=benchmark_interface,
            inputs=[dataset_radio, submission_file, submission_name, model_link, person_name],
            outputs=[metrics_output, existing_submissions_output]
        )
    return demo

def main():
    app = create_gradio_app()
    app.launch()

if __name__ == "__main__":
    main()
```
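Since results land in a plain SQLite file, the leaderboard can also be inspected outside the UI. A minimal sketch, assuming `submissions.db` sits in the working directory:

```python
import pandas as pd
from sqlalchemy import create_engine

# Read the stored leaderboard directly; the table and column names come from
# the Submission model in app.py.
engine = create_engine("sqlite:///submissions.db")
top = pd.read_sql(
    "SELECT dataset_name, submission_name, f1 FROM submissions ORDER BY f1 DESC",
    engine,
)
print(top.head(10))
```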
requirements.txt
ADDED
@@ -0,0 +1,3 @@
```
sqlalchemy
pandas
scikit-learn
```
submissions.db
ADDED
Binary file (8.19 kB).
testsets/spaccc_gender_dataset_test.csv
ADDED
The diff for this file is too large to render.