|
from pathlib import Path |
|
import matplotlib.pyplot as plt |
|
from sklearn import metrics |
|
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc, classification_report, RocCurveDisplay |
|
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay |
|
from train_test import test_model |
|
import torch |
|
from net import Net |
|
from batch_sampler import BatchSampler |
|
from image_dataset import ImageDataset |
|
from sklearn.metrics import roc_curve, auc, RocCurveDisplay |
|
from sklearn.preprocessing import label_binarize |
|
from itertools import cycle |
|
from scipy import interp |
|
import numpy as np |
|
from sklearn.metrics import roc_auc_score |
|
from sklearn.preprocessing import LabelBinarizer |
|
import seaborn as sns |
|
|
|
|
|
|
|
def ConfusionMatrix(y_pred, y):
    """Print a classification report and derive confusion-matrix statistics.

    Args:
        y_pred: Predicted class labels. Must expose ``.cpu()``
            (presumably a 1-D torch tensor -- confirm against the caller).
        y: Ground-truth class labels. Must expose ``.cpu()``; flattened to
            one dimension before being handed to scikit-learn.

    Returns:
        A tuple ``(disp, FP, FN, TP, TN)``: ``disp`` is a
        ``ConfusionMatrixDisplay`` built from the confusion matrix, and the
        remaining entries are per-class arrays of false positives, false
        negatives, true positives and true negatives derived from it.
    """
    y_pred = y_pred.cpu()
    # Flatten once and feed the same 1-D labels to both the report and the
    # matrix, so the two outputs always describe the same targets.
    y_true = y.cpu().reshape(-1)

    print(classification_report(y_true, y_pred, zero_division=1))

    conf = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=conf)

    # Rows of ``conf`` are true classes, columns are predicted classes.
    TP = np.diag(conf)
    FP = conf.sum(axis=0) - TP  # predicted as the class, but actually another
    FN = conf.sum(axis=1) - TP  # actually the class, but predicted as another
    TN = conf.sum() - (FP + FN + TP)

    return disp, FP, FN, TP, TN
|
|
|
|
|
|
|
def ROC(y_pred_prob, y_pred, y):
    """Build a ROC display of prediction correctness against the scores.

    Each sample gets the binary target 1 when its predicted label equals the
    true label and 0 otherwise; the ROC curve is computed from those targets
    and the leading entries of the flattened ``y_pred_prob``.

    NOTE(review): the scores are read from the *flattened* probability
    tensor, which is only meaningful when ``y_pred_prob`` carries one score
    per sample (e.g. the confidence of the predicted class) -- confirm
    against the caller.

    Args:
        y_pred_prob: Scores; must expose ``.cpu()`` (torch tensor).
        y_pred: Predicted class labels; must expose ``.cpu()``.
        y: Ground-truth class labels; must expose ``.cpu()``.

    Returns:
        A ``RocCurveDisplay`` holding the false/true positive rates and the
        area under the curve.

    Raises:
        IndexError: If there are fewer ground-truth labels than predictions.
    """
    scores = y_pred_prob.cpu().numpy().reshape(-1)
    predicted = y_pred.cpu().numpy().reshape(-1)
    labels = y.cpu().numpy().reshape(-1)

    n_samples = len(predicted)
    if len(labels) < n_samples:
        raise IndexError(
            f"expected at least {n_samples} labels, got {len(labels)}"
        )

    # 1 where the prediction was correct, 0 where it was wrong.
    is_correct = (predicted == labels[:n_samples]).astype(int)

    # Use as many scores as there are predictions instead of a fixed count,
    # so the function works for any evaluation-set size.
    fpr, tpr, _ = metrics.roc_curve(is_correct, scores[:n_samples])
    roc_auc = metrics.auc(fpr, tpr)

    return metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ROC_multiclass(y_true, y_pred_prob, n_classes):
    """Plot one-vs-rest ROC curves for every class and show the figure.

    Args:
        y_true: Ground-truth class labels, expected to take values in
            ``range(n_classes)``.
        y_pred_prob: Per-class scores indexed as ``y_pred_prob[:, i]``
            (presumably shape (n_samples, n_classes) -- confirm against the
            caller).
        n_classes: Number of classes to draw a curve for.

    Returns:
        The ``matplotlib.pyplot`` module, after the figure has been drawn
        and ``plt.show()`` has been called.
    """
    y_true = label_binarize(y_true, classes=list(range(n_classes)))

    # label_binarize collapses a two-class problem into a single column for
    # the positive class; expand it so column i is the indicator of class i
    # and the per-class indexing below stays valid.
    if n_classes == 2 and y_true.shape[1] == 1:
        y_true = np.hstack((1 - y_true, y_true))

    fpr = {}
    tpr = {}
    roc_auc = {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_pred_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    plt.figure()
    # The palette is cycled, so more than six classes reuse colors.
    colors = cycle(['blue', 'red', 'green', 'yellow', 'orange', 'purple'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(
            fpr[i],
            tpr[i],
            color=color,
            lw=2,
            label=f'ROC curve of class {i} (area = {roc_auc[i]:0.2f})',
        )

    # Diagonal reference line for a random classifier.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multiclass ROC')
    plt.legend(loc="lower right")
    plt.show()

    return plt
|
|