# Spaces status: Runtime error
import gradio as gr | |
import pandas as pd | |
import os | |
import re | |
from datetime import datetime | |
# CSV file that accumulates one summary row per evaluated submission.
LEADERBOARD_FILE = "leaderboard.csv"
def clean_answer(answer):
    """Normalize a raw predicted answer to a single multiple-choice letter.

    Args:
        answer: Raw value from the ``predicted_answer`` column. It may be a
            string, a number, or a missing value (NaN/None).

    Returns:
        One of ``"A"``, ``"B"``, ``"C"``, ``"D"``, or ``None`` when no choice
        letter can be extracted. Callers treat ``None`` as an invalid
        prediction.
    """
    if pd.isna(answer):
        return None
    text = str(answer)
    # Prefer a choice letter that stands on its own ("B", "(c)", "Answer: D").
    # Stripping every non a-d character instead would turn "Answer: C" into
    # "AC" and pick the "A" from the word "Answer". Upper-case tokens are
    # tried first, so an explicit capital beats a stray article like "a".
    for pattern in (r"\b([A-D])\b", r"\b([a-d])\b"):
        match = re.search(pattern, text)
        if match:
            return match.group(1).upper()
    # Fallback heuristic: keep only A-D/a-d characters and take the first
    # one, e.g. "bad" -> "B".
    letters = re.sub(r"[^A-Da-d]", "", text)
    return letters[0].upper() if letters else None
def write_evaluation_results(results, output_file):
    """Write a human-readable evaluation report to a file and echo it to stdout.

    Args:
        results: Metrics dict with the keys ``model_name``,
            ``overall_accuracy``, ``valid_accuracy``, ``total_questions``,
            ``valid_predictions``, ``invalid_predictions``,
            ``correct_predictions`` and ``field_performance``.
            ``field_performance`` maps each field name to a dict with
            ``accuracy``, ``valid_accuracy``, ``correct``, ``total`` and
            ``invalid``.
        output_file: Destination path. Its parent directory is created if
            it does not exist yet.
    """
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    lines = [
        f"Evaluation Results for Model: {results['model_name']}",
        f"Timestamp: {timestamp}",
        "-" * 50,
        f"Overall Accuracy (including invalid): {results['overall_accuracy']:.2%}",
        f"Accuracy (valid predictions only): {results['valid_accuracy']:.2%}",
        f"Total Questions: {results['total_questions']}",
        f"Valid Predictions: {results['valid_predictions']}",
        f"Invalid/Malformed Predictions: {results['invalid_predictions']}",
        f"Correct Predictions: {results['correct_predictions']}",
        "\nPerformance by Field:",
        "-" * 50,
    ]
    for field, metrics in results["field_performance"].items():
        lines.extend([
            f"\nField: {field}",
            f"Accuracy (including invalid): {metrics['accuracy']:.2%}",
            f"Accuracy (valid only): {metrics['valid_accuracy']:.2%}",
            f"Correct: {metrics['correct']}/{metrics['total']}",
            f"Invalid predictions: {metrics['invalid']}",
        ])
    report = "\n".join(lines)
    # Explicit UTF-8: model and field names come from user CSVs and may be
    # non-ASCII, so do not depend on the platform's default locale encoding.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(report)
    print(report)
    print(f"\nResults have been saved to: {output_file}")
def update_leaderboard(results):
    """Append one summary row for ``results`` to the leaderboard CSV.

    The row stores the model name, both accuracies pre-formatted as
    percentage strings, the correct/total counts and the current local
    timestamp. Rows already in ``LEADERBOARD_FILE`` are kept, the new row
    goes last, and the whole table is rewritten without an index column.
    """
    row = {
        "Model Name": results['model_name'],
        "Overall Accuracy": f"{results['overall_accuracy']:.2%}",
        "Valid Accuracy": f"{results['valid_accuracy']:.2%}",
        "Correct Predictions": results['correct_predictions'],
        "Total Questions": results['total_questions'],
        "Timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    frames = []
    if os.path.exists(LEADERBOARD_FILE):
        frames.append(pd.read_csv(LEADERBOARD_FILE))
    frames.append(pd.DataFrame([row]))
    combined = pd.concat(frames, ignore_index=True)
    combined.to_csv(LEADERBOARD_FILE, index=False)
def display_leaderboard():
    """Return the leaderboard as a text table for the UI.

    Returns:
        ``"Leaderboard is empty."`` when the CSV does not exist, is a
        zero-byte file, or has no rows. Otherwise the table is rendered as
        Markdown when the optional ``tabulate`` package is installed, or as
        plain aligned text when it is not.
    """
    if not os.path.exists(LEADERBOARD_FILE):
        return "Leaderboard is empty."
    try:
        leaderboard_df = pd.read_csv(LEADERBOARD_FILE)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse.
        return "Leaderboard is empty."
    if leaderboard_df.empty:
        return "Leaderboard is empty."
    try:
        # DataFrame.to_markdown needs the optional `tabulate` dependency and
        # raises ImportError without it; degrade to a plain-text table.
        return leaderboard_df.to_markdown(index=False)
    except ImportError:
        return leaderboard_df.to_string(index=False)
def _resolve_upload_path(uploaded):
    """Return the filesystem path of a Gradio upload (path string or object with ``.name``)."""
    # Older Gradio passes a tempfile-like object exposing `.name`; gr.File with
    # type="filepath" passes a plain string, which has no `.name` attribute.
    return getattr(uploaded, "name", uploaded)


def _extract_model_name(path):
    """Return the text after the first underscore of the file stem, or ``"unknown_model"``.

    ``predictions_gpt-4.5_chat.csv`` yields ``gpt-4.5_chat``. Splitting on
    every ``_`` and ``.`` would truncate that name to ``gpt-4``.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    _, sep, model_name = stem.partition("_")
    return model_name if sep and model_name else "unknown_model"


def _tally(frame):
    """Return ``(correct, total, valid_total)`` counts for the rows of ``frame``."""
    valid = frame.dropna(subset=["pred_answer"])
    correct = int((valid["pred_answer"] == valid["gt_answer"]).sum())
    return correct, len(frame), len(valid)


def evaluate_predictions(prediction_file, ground_truth_file="ground_truth.csv"):
    """Score an uploaded prediction CSV against the ground truth and record it.

    The prediction CSV must contain ``question_id`` and ``predicted_answer``
    columns. The ground-truth CSV must contain ``question_id``, ``Answer``
    and ``Field``. Every ground-truth question is scored: questions missing
    from the upload, and predictions that ``clean_answer`` cannot parse, both
    count as invalid predictions. On success the leaderboard is updated and a
    text report is written to ``evaluation_results.txt``.

    Args:
        prediction_file: The Gradio upload (a path string or a file-like
            object with ``.name``), or a falsy value if nothing was uploaded.
        ground_truth_file: Path of the reference answers CSV.

    Returns:
        A ``(status_message, report_path_or_None)`` tuple for the UI.
    """
    if not prediction_file:
        return "Prediction file not uploaded", None
    if not os.path.exists(ground_truth_file):
        return "Ground truth file not found", None
    try:
        prediction_path = _resolve_upload_path(prediction_file)
        predictions_df = pd.read_csv(prediction_path)
        ground_truth_df = pd.read_csv(ground_truth_file)

        missing = {"question_id", "predicted_answer"} - set(predictions_df.columns)
        if missing:
            return (
                "Prediction file is missing required column(s): "
                + ", ".join(sorted(missing)),
                None,
            )
        missing = {"question_id", "Answer", "Field"} - set(ground_truth_df.columns)
        if missing:
            return (
                "Ground truth file is missing required column(s): "
                + ", ".join(sorted(missing)),
                None,
            )

        model_name = _extract_model_name(prediction_path)

        # Keep only the needed upload columns so an uploaded "Answer"/"Field"
        # column cannot collide (and get suffixed) in the merge, and count
        # each question at most once.
        predictions_df = predictions_df[["question_id", "predicted_answer"]].drop_duplicates(
            subset="question_id", keep="first"
        )
        # Left join from the ground truth: unanswered questions get a NaN
        # prediction and are scored as invalid instead of being silently
        # dropped, which would let partial submissions inflate accuracy.
        merged_df = pd.merge(ground_truth_df, predictions_df, on="question_id", how="left")
        merged_df["pred_answer"] = merged_df["predicted_answer"].apply(clean_answer)
        # Normalize reference letters so values like " a" still match "A".
        merged_df["gt_answer"] = merged_df["Answer"].astype(str).str.strip().str.upper()

        invalid_predictions = int(merged_df["pred_answer"].isna().sum())
        correct_predictions, total_predictions, total_valid_predictions = _tally(merged_df)
        overall_accuracy = correct_predictions / total_predictions if total_predictions else 0
        valid_accuracy = (
            correct_predictions / total_valid_predictions if total_valid_predictions else 0
        )

        field_metrics = {}
        # dropna=False so rows with a missing Field are still reported.
        for field, field_data in merged_df.groupby("Field", sort=False, dropna=False):
            field_correct, field_total, field_valid_total = _tally(field_data)
            field_metrics[field] = {
                "accuracy": field_correct / field_total if field_total else 0,
                "valid_accuracy": field_correct / field_valid_total if field_valid_total else 0,
                "correct": field_correct,
                "total": field_total,
                "invalid": field_total - field_valid_total,
            }

        results = {
            "model_name": model_name,
            "overall_accuracy": overall_accuracy,
            "valid_accuracy": valid_accuracy,
            "total_questions": total_predictions,
            "valid_predictions": total_valid_predictions,
            "invalid_predictions": invalid_predictions,
            "correct_predictions": correct_predictions,
            "field_performance": field_metrics,
        }
        update_leaderboard(results)
        output_file = "evaluation_results.txt"
        write_evaluation_results(results, output_file)
        return "Evaluation completed successfully! Leaderboard updated.", output_file
    except Exception as e:
        # UI boundary: surface any failure as a status message rather than
        # letting the event handler crash.
        return f"Error during evaluation: {str(e)}", None
# Gradio Interface: an "Evaluate" tab for uploading predictions and a
# "Leaderboard" tab showing accumulated results.
description = "Upload a prediction CSV file to evaluate predictions against the ground truth and update the leaderboard."
with gr.Blocks() as demo:
    gr.Markdown("# Prediction Evaluation Tool with Leaderboard")
    # Show the usage hint under the title.
    gr.Markdown(description)
    with gr.Tab("Evaluate"):
        file_input = gr.File(label="Upload Prediction CSV", file_types=[".csv"])
        eval_status = gr.Textbox(label="Evaluation Status")
        eval_results_file = gr.File(label="Download Evaluation Results")
        eval_button = gr.Button("Evaluate")
        eval_button.click(
            evaluate_predictions, inputs=file_input, outputs=[eval_status, eval_results_file]
        )
    with gr.Tab("Leaderboard"):
        leaderboard_text = gr.Textbox(label="Leaderboard", interactive=False)
        refresh_button = gr.Button("Refresh Leaderboard")
        refresh_button.click(display_leaderboard, outputs=leaderboard_text)
    # Fill the leaderboard when the page loads instead of leaving it blank
    # until "Refresh Leaderboard" is clicked.
    demo.load(display_leaderboard, outputs=leaderboard_text)

if __name__ == "__main__":
    demo.launch()