# eval_model/app.py
import os
import json
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
import torch
import re
from collections import Counter
import string
from huggingface_hub import login
import gradio as gr
import pandas as pd
from datetime import datetime


def normalize_answer(s):
    """Normalize answer for evaluation"""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
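
# Illustrative example (comment only, not executed):
#   normalize_answer("The Supplier, shall not...") -> "supplier shall not"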


def f1_score_qa(prediction, ground_truth):
    """Calculate F1 score for QA"""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    if len(prediction_tokens) == 0 or len(ground_truth_tokens) == 0:
        return int(prediction_tokens == ground_truth_tokens)
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
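
# Worked example (for reference only): prediction "the quick brown fox" vs. ground
# truth "quick fox" normalizes to 3 vs. 2 tokens with 2 tokens in common, so
# precision = 2/3, recall = 2/2 = 1.0, and F1 = 2 * (2/3 * 1.0) / (2/3 + 1.0) = 0.8.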


def exact_match_score(prediction, ground_truth):
    """Calculate exact match score"""
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Calculate maximum score over all ground truth answers"""
    scores = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores.append(score)
    return max(scores) if scores else 0
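
# Example (illustrative): with ground_truths ["30 days", "thirty (30) days"], a
# prediction of "30 days" is scored against both answers and the best score is kept,
# so the exact match is 1 even though it misses the longer variant.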


def evaluate_model():
    # Authenticate with Hugging Face using the token
    hf_token = os.getenv("EVAL_TOKEN")
    if hf_token:
        try:
            login(token=hf_token)
            print("✓ Authenticated with Hugging Face")
        except Exception as e:
            print(f"⚠ Warning: Could not authenticate with HF token: {e}")
    else:
        print("⚠ Warning: EVAL_TOKEN not found in environment variables")

    print("Loading model and tokenizer...")
    model_name = "AvocadoMuffin/roberta-cuad-qa-v3"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        model = AutoModelForQuestionAnswering.from_pretrained(model_name, token=hf_token)
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        print("✓ Model loaded successfully")
        return qa_pipeline, hf_token
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        return None, None
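
# For reference: a transformers question-answering pipeline returns a dict of the
# form {"answer": str, "score": float, "start": int, "end": int} per query; only
# "answer" and "score" are used below.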


def inspect_dataset_structure(dataset, num_samples=3):
    """Inspect dataset structure for debugging"""
    print(f"Dataset structure inspection:")
    print(f"Dataset type: {type(dataset)}")
    print(f"Dataset length: {len(dataset)}")
    if len(dataset) > 0:
        sample = dataset[0]
        print(f"Sample keys: {list(sample.keys()) if isinstance(sample, dict) else 'Not a dict'}")
        print(f"Sample structure:")
        for key, value in sample.items():
            print(f" {key}: {type(value)} - {str(value)[:100]}...")
    return dataset
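
# Note: in the Hub version of CUAD the samples generally follow the SQuAD schema,
# i.e. {"id", "title", "context", "question", "answers"}, where "answers" holds
# parallel "text" and "answer_start" lists. This is an assumption worth confirming
# via inspect_dataset_structure() above before trusting the field handling below.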


def run_evaluation(num_samples, progress=gr.Progress()):
    """Run evaluation and return results for Gradio interface"""
    # Load model
    qa_pipeline, hf_token = evaluate_model()
    if qa_pipeline is None:
        return "❌ Failed to load model", pd.DataFrame(), None

    progress(0.1, desc="Loading CUAD dataset...")

    # Load dataset - try multiple approaches
    dataset = None
    test_data = None
    try:
        # Try cuad dataset directly
        print("Attempting to load CUAD dataset...")
        dataset = load_dataset("cuad", token=hf_token)
        test_data = dataset["test"]
        print(f"✓ Loaded CUAD dataset with {len(test_data)} samples")
        # Inspect structure
        test_data = inspect_dataset_structure(test_data)
    except Exception as e:
        print(f"Error loading CUAD dataset: {e}")
        try:
            # Try squad format as fallback
            print("Trying SQuAD format...")
            dataset = load_dataset("squad", split="validation", token=hf_token)
            test_data = dataset.select(range(min(1000, len(dataset))))
            print(f"✓ Loaded SQuAD dataset as fallback with {len(test_data)} samples")
        except Exception as e2:
            return f"❌ Error loading any dataset: {e2}", pd.DataFrame(), None

    if test_data is None:
        return "❌ No test data available", pd.DataFrame(), None

    # Limit samples
    num_samples = min(num_samples, len(test_data))
    test_subset = test_data.select(range(num_samples))

    progress(0.2, desc=f"Starting evaluation on {num_samples} samples...")

    # Initialize metrics
    exact_matches = []
    f1_scores = []
    predictions = []

    # Run evaluation
    for i, example in enumerate(test_subset):
        progress((0.2 + 0.7 * i / num_samples), desc=f"Processing sample {i+1}/{num_samples}")
        try:
            # Handle different dataset formats
            if "context" in example:
                context = example["context"]
            elif "text" in example:
                context = example["text"]
            else:
                print(f"Warning: No context found in sample {i}")
                continue

            if "question" in example:
                question = example["question"]
            elif "title" in example:
                question = example["title"]
            else:
                print(f"Warning: No question found in sample {i}")
                continue

            # Handle answers field
            ground_truths = []
            if "answers" in example:
                answers = example["answers"]
                if isinstance(answers, dict):
                    if "text" in answers:
                        if isinstance(answers["text"], list):
                            ground_truths = [ans for ans in answers["text"] if ans.strip()]
                        else:
                            ground_truths = [answers["text"]] if answers["text"].strip() else []
                elif isinstance(answers, list):
                    ground_truths = answers

            # Skip if no ground truth
            if not ground_truths:
                print(f"Warning: No ground truth found for sample {i}")
                continue

            # Get model prediction
            try:
                result = qa_pipeline(question=question, context=context)
                predicted_answer = result["answer"]
                confidence = result["score"]
            except Exception as e:
                print(f"Error getting prediction for sample {i}: {e}")
                continue

            # Calculate metrics using max over ground truths
            em = max_over_ground_truths(exact_match_score, predicted_answer, ground_truths)
            f1 = max_over_ground_truths(f1_score_qa, predicted_answer, ground_truths)

            exact_matches.append(em)
            f1_scores.append(f1)
            predictions.append({
                "Sample_ID": i + 1,
                "Question": question[:100] + "..." if len(question) > 100 else question,
                "Predicted_Answer": predicted_answer[:100] + "..." if len(predicted_answer) > 100 else predicted_answer,
                "Ground_Truth": ground_truths[0][:100] + "..." if len(ground_truths[0]) > 100 else ground_truths[0],
                "Num_Ground_Truths": len(ground_truths),
                "Exact_Match": em,
                "F1_Score": round(f1, 3),
                "Confidence": round(confidence, 3)
            })
        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            continue

    progress(0.9, desc="Calculating final metrics...")

    # Calculate final metrics
    if len(exact_matches) == 0:
        return "❌ No samples were successfully processed", pd.DataFrame(), None

    avg_exact_match = np.mean(exact_matches) * 100
    avg_f1_score = np.mean(f1_scores) * 100

    # Calculate additional statistics
    high_confidence_samples = [p for p in predictions if p['Confidence'] > 0.8]
    perfect_matches = [p for p in predictions if p['Exact_Match'] == 1]
    high_f1_samples = [p for p in predictions if p['F1_Score'] > 0.8]

    # Create results summary
    results_summary = f"""
# 📊 CUAD Model Evaluation Results

## 🎯 Overall Performance
- **Model**: AvocadoMuffin/roberta-cuad-qa-v3
- **Dataset**: CUAD (Contract Understanding Atticus Dataset)
- **Samples Evaluated**: {len(exact_matches)}
- **Evaluation Date**: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

## 📈 Core Metrics
- **Exact Match Score**: {avg_exact_match:.2f}%
- **F1 Score**: {avg_f1_score:.2f}%

## 🔍 Performance Analysis
- **High Confidence Predictions (>0.8)**: {len(high_confidence_samples)} ({len(high_confidence_samples)/len(predictions)*100:.1f}%)
- **Perfect Matches**: {len(perfect_matches)} ({len(perfect_matches)/len(predictions)*100:.1f}%)
- **High F1 Scores (>0.8)**: {len(high_f1_samples)} ({len(high_f1_samples)/len(predictions)*100:.1f}%)

## 📊 Distribution
- **Average Confidence**: {np.mean([p['Confidence'] for p in predictions]):.3f}
- **Median F1 Score**: {np.median([p['F1_Score'] for p in predictions]):.3f}
- **Samples with Multiple Ground Truths**: {len([p for p in predictions if p['Num_Ground_Truths'] > 1])}

## 🎯 Evaluation Quality
The evaluation accounts for multiple ground truth answers where available, using the maximum score across all valid answers for each question.
"""

    # Create detailed results DataFrame
    df = pd.DataFrame(predictions)

    # Save results to file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"cuad_evaluation_results_{timestamp}.json"

    detailed_results = {
        "model_name": "AvocadoMuffin/roberta-cuad-qa-v3",
        "dataset": "cuad",
        "num_samples": len(exact_matches),
        "exact_match_score": avg_exact_match,
        "f1_score": avg_f1_score,
        "evaluation_date": datetime.now().isoformat(),
        "evaluation_methodology": "max_over_ground_truths",
        "predictions": predictions,
        "summary_stats": {
            "avg_confidence": float(np.mean([p['Confidence'] for p in predictions])),
            "median_f1": float(np.median([p['F1_Score'] for p in predictions])),
            "samples_with_multiple_ground_truths": len([p for p in predictions if p['Num_Ground_Truths'] > 1])
        }
    }

    try:
        with open(results_file, "w") as f:
            json.dump(detailed_results, f, indent=2)
        print(f"✓ Results saved to {results_file}")
    except Exception as e:
        print(f"⚠ Warning: Could not save results file: {e}")
        results_file = None

    progress(1.0, desc="✅ Evaluation completed!")

    return results_summary, df, results_file
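
# Minimal headless usage sketch (assumptions: EVAL_TOKEN is set, and the default
# gr.Progress() can be called outside a Gradio event in the installed Gradio version):
#   summary_md, results_df, json_path = run_evaluation(num_samples=50)
#   print(summary_md)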


def create_gradio_interface():
    """Create Gradio interface for CUAD evaluation"""
    with gr.Blocks(title="CUAD Model Evaluator", theme=gr.themes.Soft()) as demo:
        gr.HTML("""
        <div style="text-align: center; padding: 20px;">
            <h1>🏛️ CUAD Model Evaluation Dashboard</h1>
            <p>Evaluate your CUAD (Contract Understanding Atticus Dataset) Question Answering model</p>
            <p><strong>Model:</strong> AvocadoMuffin/roberta-cuad-qa-v3</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>⚙️ Evaluation Settings</h3>")

                num_samples = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                    label="Number of samples to evaluate",
                    info="Choose between 10-500 samples (more samples = more accurate but slower)"
                )

                evaluate_btn = gr.Button(
                    "🚀 Start Evaluation",
                    variant="primary",
                    size="lg"
                )

                gr.HTML("""
                <div style="margin-top: 20px; padding: 15px; background-color: #f0f0f0; border-radius: 8px;">
                    <h4>📋 What this evaluates:</h4>
                    <ul>
                        <li><strong>Exact Match</strong>: Percentage of perfect predictions</li>
                        <li><strong>F1 Score</strong>: Token-level overlap between prediction and ground truth</li>
                        <li><strong>Confidence</strong>: Model's confidence in its predictions</li>
                        <li><strong>Max-over-GT</strong>: Best score across multiple ground truth answers</li>
                    </ul>
                </div>
                """)

            with gr.Column(scale=2):
                gr.HTML("<h3>📊 Results</h3>")
                results_summary = gr.Markdown(
                    value="Click '🚀 Start Evaluation' to begin...",
                    label="Evaluation Summary"
                )

        gr.HTML("<hr>")

        with gr.Row():
            gr.HTML("<h3>📋 Detailed Results</h3>")

        with gr.Row():
            detailed_results = gr.Dataframe(
                label="Sample-by-Sample Results",
                interactive=False,
                wrap=True
            )

        with gr.Row():
            download_file = gr.File(
                label="📥 Download Complete Results (JSON)",
                visible=False
            )

        # Event handlers
        def handle_evaluation(num_samples):
            summary, df, file_path = run_evaluation(num_samples)
            if file_path and os.path.exists(file_path):
                return summary, df, gr.update(visible=True, value=file_path)
            else:
                return summary, df, gr.update(visible=False)

        evaluate_btn.click(
            fn=handle_evaluation,
            inputs=[num_samples],
            outputs=[results_summary, detailed_results, download_file],
            show_progress=True
        )

        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 30px; padding: 20px; color: #666;">
            <p>🤖 Powered by Hugging Face Transformers & Gradio</p>
            <p>📚 CUAD Dataset by The Atticus Project</p>
        </div>
        """)

    return demo
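
# Note: create_gradio_interface() only builds and wires the UI; the model itself is
# loaded lazily inside run_evaluation() the first time the evaluate button is clicked.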


if __name__ == "__main__":
    print("CUAD Model Evaluation with Gradio Interface")
    print("=" * 50)

    # Check if CUDA is available
    if torch.cuda.is_available():
        print(f"✓ CUDA available: {torch.cuda.get_device_name(0)}")
    else:
        print("! Running on CPU")

    # Create and launch Gradio interface
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True
    )