|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from sklearn.metrics import roc_curve, auc |
|
import os |
|
import json |
|
from pathlib import Path |
|
|
|
def _finalize_roc_figure(title, save_path, legend_kwargs, tight_bbox=False):
    """Add the chance diagonal, labels, title, legend and grid to the current figure, then save and close it."""
    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(**legend_kwargs)
    plt.grid(True)
    plt.tight_layout()
    if tight_bbox:
        # Needed when the legend is anchored outside the axes, otherwise it is cropped.
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    else:
        plt.savefig(save_path, dpi=300)
    plt.close()


def _bootstrap_roc(y_true, y_score, mean_fpr, n_bootstrap, rng):
    """Resample (y_true, y_score) with replacement; return the TPR curves interpolated on mean_fpr and the AUCs."""
    n_samples = len(y_true)
    boot_tprs = []
    boot_aucs = []
    for _ in range(n_bootstrap):
        indices = rng.randint(0, n_samples, n_samples)
        sample_true = y_true[indices]
        # A ROC curve is undefined when the resample contains a single label value.
        if len(np.unique(sample_true)) < 2:
            continue

        fpr, tpr, _thresholds = roc_curve(sample_true, y_score[indices])

        # Put every resampled curve on the same FPR grid so they can be averaged point-wise.
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        boot_tprs.append(interp_tpr)
        boot_aucs.append(auc(fpr, tpr))
    return boot_tprs, boot_aucs


def plot_roc_curves(predictions_path, output_dir=None, n_bootstrap=1000, random_state=None):
    """
    Plot ROC curves from model predictions

    Three kinds of PNG files are written to a ``plots`` sub-directory of the
    output directory:

    * ``roc_all_classes.png`` - one ROC curve per toxicity class on shared axes.
    * ``roc_<class_name>.png`` - the ROC curve of a single class together with a
      bootstrap band (mean +/- 1 standard deviation of the resampled curves).
    * ``roc_by_language.png`` - ROC curves of the first class ('toxic'), one per
      language id present in the data.

    Args:
        predictions_path (str): Path to the .npz file containing predictions. The
            archive must provide the arrays ``predictions``, ``labels`` and
            ``langs``. ``predictions`` and ``labels`` are indexed as
            ``[:, class_index]`` for six classes; ``langs`` is compared against
            the integer ids 0-6 and used as a row mask (presumably one language
            id per sample - confirm against the code that writes the file).
        output_dir (str, optional): Directory to save plots. If None, will use same directory as predictions
        n_bootstrap (int, optional): Number of bootstrap resamples drawn per class
            for the uncertainty band. Defaults to 1000.
        random_state (int, optional): Seed for the bootstrap resampling. If None,
            NumPy's global random state is used, so results vary between runs
            unless the caller seeded it.
    """
    # np.load returns an NpzFile that holds the archive open; read the arrays
    # and close it right away instead of leaking the file handle.
    with np.load(predictions_path) as data:
        predictions = data['predictions']
        labels = data['labels']
        langs = data['langs']

    if output_dir is None:
        output_dir = os.path.dirname(predictions_path)
    plots_dir = os.path.join(output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # Column order of `predictions` / `labels`.
    toxicity_types = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # Language ids as stored in `langs`, mapped to display names for the legend.
    id_to_lang = {
        0: 'English (en)',
        1: 'Russian (ru)',
        2: 'Turkish (tr)',
        3: 'Spanish (es)',
        4: 'French (fr)',
        5: 'Italian (it)',
        6: 'Portuguese (pt)',
    }

    # The legacy global generator and RandomState expose the same randint API,
    # so the default (unseeded) path behaves exactly as before.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Full-data ROC per class, computed once and reused by both plot families.
    full_curves = []
    for i in range(len(toxicity_types)):
        fpr, tpr, _thresholds = roc_curve(labels[:, i], predictions[:, i])
        full_curves.append((fpr, tpr, auc(fpr, tpr)))

    # --- 1. All classes on one figure ---------------------------------------
    plt.figure(figsize=(10, 8))
    for class_name, (fpr, tpr, roc_auc) in zip(toxicity_types, full_curves):
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.3f})')
    _finalize_roc_figure(
        'ROC Curves - All Classes',
        os.path.join(plots_dir, 'roc_all_classes.png'),
        {'bbox_to_anchor': (1.05, 1), 'loc': 'upper left'},
        tight_bbox=True,
    )

    # --- 2. One figure per class with a bootstrap band ----------------------
    mean_fpr = np.linspace(0, 1, 100)
    for i, class_name in enumerate(toxicity_types):
        plt.figure(figsize=(8, 6))

        fpr, tpr, roc_auc = full_curves[i]
        plt.plot(fpr, tpr, 'b-', label=f'ROC (AUC = {roc_auc:.3f})')

        boot_tprs, boot_aucs = _bootstrap_roc(
            labels[:, i], predictions[:, i], mean_fpr, n_bootstrap, rng
        )

        # If no resample contained both label values there is nothing to
        # average; skip the band rather than plotting NaNs and "nan ± nan".
        if boot_tprs:
            tpr_matrix = np.array(boot_tprs)
            mean_tpr = np.mean(tpr_matrix, axis=0)
            std_tpr = np.std(tpr_matrix, axis=0)

            # Clip the band to the valid TPR range [0, 1].
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

            plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                             label='±1 std. dev.')

            auc_mean = np.mean(boot_aucs)
            auc_std = np.std(boot_aucs)
            # Invisible artist whose only purpose is to add a text line to the legend.
            plt.plot([], [], ' ', label=f'AUC = {auc_mean:.3f} ± {auc_std:.3f}')

        _finalize_roc_figure(
            f'ROC Curve - {class_name}',
            os.path.join(plots_dir, f'roc_{class_name}.png'),
            {'loc': 'lower right'},
        )

    # --- 3. First class ('toxic') split by language -------------------------
    plt.figure(figsize=(10, 8))
    for lang_id, lang_name in id_to_lang.items():
        lang_mask = langs == lang_id
        # Skip languages that are absent or whose labels take a single value:
        # the ROC curve is undefined in that case.
        if lang_mask.sum() > 0 and len(np.unique(labels[lang_mask, 0])) > 1:
            fpr, tpr, _thresholds = roc_curve(labels[lang_mask, 0], predictions[lang_mask, 0])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'{lang_name} (AUC = {roc_auc:.3f})')
    _finalize_roc_figure(
        'ROC Curves by Language - Toxic Class',
        os.path.join(plots_dir, 'roc_by_language.png'),
        {'bbox_to_anchor': (1.05, 1), 'loc': 'upper left'},
        tight_bbox=True,
    )

    print(f"\nROC curves have been saved to {plots_dir}")
    print("\nGenerated plots:")
    print("1. roc_all_classes.png - ROC curves for all toxicity classes")
    print("2. roc_[class_name].png - Individual ROC curves with confidence intervals for each class")
    print("3. roc_by_language.png - ROC curves for each language (toxic class)")
|
|
|
if __name__ == '__main__':
    # Script entry point: find the newest evaluation run under
    # `evaluation_results` and plot ROC curves for its predictions file.
    eval_dir = 'evaluation_results'

    if not os.path.exists(eval_dir):
        print(f"Evaluation directory {eval_dir} not found")
    else:
        # Run directories are named `eval_*`; in descending lexicographic order
        # the first entry is treated as the latest run.
        run_dirs = [entry for entry in os.listdir(eval_dir) if entry.startswith('eval_')]
        run_dirs.sort(reverse=True)

        if not run_dirs:
            print(f"No evaluation directories found in {eval_dir}")
        else:
            latest_eval = os.path.join(eval_dir, run_dirs[0])
            npz_path = os.path.join(latest_eval, 'predictions.npz')

            if not os.path.exists(npz_path):
                print(f"No predictions file found in {latest_eval}")
            else:
                plot_roc_curves(npz_path)