# Deeptanshuu's picture
# Upload folder using huggingface_hub
# d187b57 verified
# raw
# history blame
# 6.13 kB
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import os
import json
from pathlib import Path
def _has_both_classes(y):
    """Return True if the label vector ``y`` contains at least two distinct values."""
    return len(np.unique(y)) > 1


def _plot_all_classes(predictions, labels, toxicity_types, plots_dir):
    """Save one figure overlaying the ROC curve of every class that has a defined ROC."""
    fig, ax = plt.subplots(figsize=(10, 8))
    for i, class_name in enumerate(toxicity_types):
        if not _has_both_classes(labels[:, i]):
            # roc_curve yields NaN rates (and a warning) when only one label value is present.
            print(f"Skipping '{class_name}' in roc_all_classes.png: only one label value present")
            continue
        fpr, tpr, _ = roc_curve(labels[:, i], predictions[:, i])
        ax.plot(fpr, tpr, label=f'{class_name} (AUC = {auc(fpr, tpr):.3f})')
    ax.plot([0, 1], [0, 1], 'k--', label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves - All Classes')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(plots_dir, 'roc_all_classes.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def _plot_class_with_ci(y_true, y_score, class_name, plots_dir, n_bootstrap, rng):
    """Save one class's ROC curve with a bootstrap ±1 std band and bootstrap AUC mean ± std."""
    fig, ax = plt.subplots(figsize=(8, 6))

    # Main ROC curve on the full data.
    fpr, tpr, _ = roc_curve(y_true, y_score)
    ax.plot(fpr, tpr, 'b-', label=f'ROC (AUC = {auc(fpr, tpr):.3f})')

    # Bootstrap: resample rows with replacement and interpolate each TPR curve
    # onto a common FPR grid so the curves can be averaged point-wise.
    mean_fpr = np.linspace(0, 1, 100)
    n_samples = len(y_true)
    tprs = []
    aucs = []
    for _ in range(n_bootstrap):
        indices = rng.randint(0, n_samples, n_samples)
        y_boot = y_true[indices]
        # A resample containing a single label value has no defined ROC curve.
        if not _has_both_classes(y_boot):
            continue
        boot_fpr, boot_tpr, _ = roc_curve(y_boot, y_score[indices])
        interp_tpr = np.interp(mean_fpr, boot_fpr, boot_tpr)
        interp_tpr[0] = 0.0  # anchor every interpolated curve at the origin
        tprs.append(interp_tpr)
        aucs.append(auc(boot_fpr, boot_tpr))

    if tprs:
        tprs = np.array(tprs)
        mean_tpr = tprs.mean(axis=0)
        std_tpr = tprs.std(axis=0)
        # Clip the band to the valid [0, 1] rate range.
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                        label='±1 std. dev.')
        # Invisible line so the bootstrap AUC summary appears in the legend.
        ax.plot([], [], ' ', label=f'AUC = {np.mean(aucs):.3f} ± {np.std(aucs):.3f}')
    else:
        print(f"No valid bootstrap resamples for '{class_name}'; confidence band omitted")

    ax.plot([0, 1], [0, 1], 'k--', label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'ROC Curve - {class_name}')
    ax.legend(loc='lower right')
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(plots_dir, f'roc_{class_name}.png'), dpi=300)
    plt.close(fig)


def _plot_by_language(predictions, labels, langs, id_to_lang, plots_dir):
    """Save one figure overlaying the column-0 ('toxic') ROC curve for each language."""
    fig, ax = plt.subplots(figsize=(10, 8))
    for lang_id, lang_name in id_to_lang.items():
        lang_mask = langs == lang_id
        # Also rejects languages with no samples (an empty vector has no distinct values).
        if _has_both_classes(labels[lang_mask, 0]):
            fpr, tpr, _ = roc_curve(labels[lang_mask, 0], predictions[lang_mask, 0])
            ax.plot(fpr, tpr, label=f'{lang_name} (AUC = {auc(fpr, tpr):.3f})')
    ax.plot([0, 1], [0, 1], 'k--', label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves by Language - Toxic Class')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(plots_dir, 'roc_by_language.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_roc_curves(predictions_path, output_dir=None, *, n_bootstrap=1000, seed=None):
    """
    Plot ROC curves from model predictions.

    Writes three kinds of PNG files into ``<output_dir>/plots``:
    all classes overlaid, one per class with bootstrap confidence bands,
    and the 'toxic' class split by language.

    Args:
        predictions_path (str): Path to the .npz file containing ``predictions``,
            ``labels`` and ``langs`` arrays.
        output_dir (str, optional): Directory under which a ``plots`` subdirectory
            is created. If None, uses the directory containing ``predictions_path``.
        n_bootstrap (int): Number of bootstrap resamples used for the per-class
            confidence bands.
        seed (int, optional): Seed for the bootstrap resampling. If None, the global
            NumPy random state is used, so results vary unless the caller seeded it.
    """
    # Closing the NpzFile releases the file handle; indexing it has already
    # loaded each array into memory.
    with np.load(predictions_path) as data:
        # predictions/labels are indexed as [row, class] below, so they are
        # presumably (n_samples, n_classes); langs is presumably one integer
        # language id per sample — TODO confirm against the evaluation script.
        predictions = data['predictions']
        labels = data['labels']
        langs = data['langs']

    # Create output directory
    if output_dir is None:
        output_dir = os.path.dirname(predictions_path)
    plots_dir = os.path.join(output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # Column order of predictions/labels.
    toxicity_types = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    id_to_lang = {
        0: 'English (en)',
        1: 'Russian (ru)',
        2: 'Turkish (tr)',
        3: 'Spanish (es)',
        4: 'French (fr)',
        5: 'Italian (it)',
        6: 'Portuguese (pt)'
    }

    # seed=None keeps using the global NumPy RNG, so np.random.seed() set by the
    # caller still takes effect. Both objects expose randint(low, high, size).
    rng = np.random if seed is None else np.random.RandomState(seed)

    _plot_all_classes(predictions, labels, toxicity_types, plots_dir)

    for i, class_name in enumerate(toxicity_types):
        if not _has_both_classes(labels[:, i]):
            print(f"Skipping roc_{class_name}.png: only one label value present")
            continue
        _plot_class_with_ci(labels[:, i], predictions[:, i], class_name,
                            plots_dir, n_bootstrap, rng)

    _plot_by_language(predictions, labels, langs, id_to_lang, plots_dir)

    print(f"\nROC curves have been saved to {plots_dir}")
    print("\nGenerated plots:")
    print("1. roc_all_classes.png - ROC curves for all toxicity classes")
    print("2. roc_[class_name].png - Individual ROC curves with confidence intervals for each class")
    print("3. roc_by_language.png - ROC curves for each language (toxic class)")
if __name__ == '__main__':
    # Plot ROC curves for the most recent evaluation run. The newest run is taken
    # to be the lexicographically largest 'eval_*' name, which assumes the names
    # embed a sortable timestamp — TODO confirm against the evaluation script.
    eval_dir = 'evaluation_results'
    if os.path.exists(eval_dir):
        # Only consider directories. A stray file such as 'eval_summary.csv' could
        # otherwise sort after the run directories and be picked as the latest.
        eval_dirs = sorted(
            (d for d in os.listdir(eval_dir)
             if d.startswith('eval_') and os.path.isdir(os.path.join(eval_dir, d))),
            reverse=True,
        )
        if eval_dirs:
            latest_eval = os.path.join(eval_dir, eval_dirs[0])
            predictions_path = os.path.join(latest_eval, 'predictions.npz')
            if os.path.exists(predictions_path):
                plot_roc_curves(predictions_path)
            else:
                print(f"No predictions file found in {latest_eval}")
        else:
            print(f"No evaluation directories found in {eval_dir}")
    else:
        print(f"Evaluation directory {eval_dir} not found")