from itertools import cycle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
# from scipy import interp
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    RocCurveDisplay,
    auc,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
)
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.utils.multiclass import unique_labels

from batch_sampler import BatchSampler
from image_dataset import ImageDataset
from net import Net
from train_test import test_model


def create_confusion_matrix(true_labels, predicted_labels, labels=None,
                            title='Confusion Matrix'):
    """Compute a confusion matrix and show it as a blue heatmap.

    Args:
        true_labels: Ground-truth class labels (array-like).
        predicted_labels: Predicted class labels, aligned with
            ``true_labels``.
        labels: Optional ordered sequence of classes used to index the
            matrix rows/columns. If None, every label that appears in either
            input is used, in sorted order -- the same set
            ``confusion_matrix`` picks by default.
        title: Title drawn above the plot.

    Returns:
        The confusion matrix computed by ``confusion_matrix`` (rows are true
        classes, columns are predicted classes, ordered as ``labels``).
    """
    if labels is None:
        # Same label set confusion_matrix derives internally when labels=None,
        # computed here so the plot can be tick-labelled with it.
        labels = unique_labels(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels, labels=labels)

    # Without display_labels the axes are tick-labelled 0..n-1 regardless of
    # the actual class values, which mislabels e.g. classes [1, 2, 3] or
    # string classes.
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues)
    plt.title(title)
    plt.show()
    return cm