|
|
|
from batch_sampler import BatchSampler |
|
from image_dataset import ImageDataset |
|
from net import Net, ResNetModel, EfficientNetModel, EfficientNetModel_b7 |
|
from train_test import train_model, test_model |
|
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay,classification_report |
|
from visualise_performance_metrics import create_confusion_matrix, ROC_multiclass |
|
from image_dataset_BINARY import ImageDatasetBINARY |
|
from net_BINARY import Net_BINARY |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torchsummary import summary |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
from matplotlib.pyplot import figure |
|
import os |
|
import argparse |
|
import plotext |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import List |
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_curve |
|
import numpy as np |
|
|
|
def main(args: argparse.Namespace, activeloop: bool = True) -> None:
    """Train a 6-class EfficientNet model, evaluate it and save artifacts.

    The function loads the train/test ``.npy`` data sets, trains and tests the
    model for ``args.nb_epochs`` epochs, shows ROC curves, stores the model
    weights under ``model_weights/`` and the loss plot under ``artifacts/``,
    and finally prints a confusion matrix and a classification report for
    the test set.

    Args:
        args: Parsed command-line arguments. The attributes ``nb_epochs``,
            ``batch_size`` and ``balanced_batches`` are read.
        activeloop: When ``False`` the per-epoch train/test pass is skipped;
            the final evaluation and artifact saving still run.
    """
    # Single source of truth for the number of labels (was a repeated literal).
    n_classes = 6
    class_names = ['Class 0 (Atelactasis)','Class 1 (Effusion)', 'Class 2 (Infiltration)', 'Class 3 (No Finding)', 'Class 4 (Nodule)', 'Class 5 (Pneumonia)']

    # Load the train and test data sets.
    train_dataset = ImageDataset(Path('dc1/data/X_train.npy'), Path('dc1/data/Y_train.npy'))
    test_dataset = ImageDataset(Path('dc1/data/X_test.npy'), Path('dc1/data/Y_test.npy'))

    model = EfficientNetModel(n_classes=n_classes)

    # Optimizer and loss (the original built the identical optimizer twice).
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)
    loss_function = nn.CrossEntropyLoss()

    n_epochs = args.nb_epochs
    batch_size = args.batch_size

    # Set to True to force CPU execution regardless of available accelerators.
    DEBUG = False

    # Move the model to the best available device.
    if torch.cuda.is_available() and not DEBUG:
        print("@@@ CUDA device found, enabling CUDA training...")
        device = "cuda"
        model.to(device)
        summary(model, (1, 128, 128), device=device)
    elif torch.backends.mps.is_available() and not DEBUG:
        print("@@@ Apple silicon device enabled, training with Metal backend...")
        device = "mps"
        model.to(device)
    else:
        print("@@@ No GPU boosting device found, training on CPU...")
        device = "cpu"
        summary(model, (1, 128, 128), device=device)

    train_sampler = BatchSampler(
        batch_size=batch_size, dataset=train_dataset, balanced=args.balanced_batches
    )
    test_sampler = BatchSampler(
        batch_size=100, dataset=test_dataset, balanced=args.balanced_batches
    )

    mean_losses_train: List[torch.Tensor] = []
    mean_losses_test: List[torch.Tensor] = []

    # Defined up front so the ROC section below cannot hit a NameError when
    # no epoch runs (n_epochs == 0 or activeloop is False).
    fpr: dict = {}
    tpr: dict = {}
    auc: dict = {}

    for e in range(n_epochs):
        if not activeloop:
            continue

        # Training pass.
        losses = train_model(model, train_sampler, optimizer, loss_function, device)
        mean_loss = sum(losses) / len(losses)
        mean_losses_train.append(mean_loss)
        print(f"\nEpoch {e + 1} training done, loss on train set: {mean_loss}\n")

        # Fresh per-epoch ROC containers; test_model is expected to fill
        # them in place (they are read back below) -- TODO confirm.
        fpr = {x: [] for x in range(n_classes)}
        tpr = {x: [] for x in range(n_classes)}
        auc = {}

        # Testing pass; the second return value (predicted probabilities)
        # is not used here.
        losses, _ = test_model(model, test_sampler, loss_function, device, fpr, tpr, auc)
        mean_loss = sum(losses) / len(losses)
        mean_losses_test.append(mean_loss)
        print(f"\nEpoch {e + 1} testing done, loss on test set: {mean_loss}\n")

        print(auc)

        # Live terminal plot of the losses so far.
        plotext.clf()
        plotext.scatter(mean_losses_train, label="train")
        plotext.scatter(mean_losses_test, label="test")
        plotext.title("Train and test loss")
        plotext.xticks(list(range(len(mean_losses_train) + 1)))
        plotext.show()

    # NOTE(review): the source's indentation was lost, so it is unclear
    # whether the ROC figure was drawn per epoch or once; it is drawn once
    # here, from the last epoch's curves (identical for a single epoch).
    if auc:
        plt.figure(figsize=(8, 6))

        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
        # offers the same (name, lut) interface.
        colors = plt.get_cmap('viridis', n_classes).colors

        for i, (class_name, color) in enumerate(zip(class_names, colors)):
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label='{} (AUC = {:.2f})'.format(class_name, auc[i]))

        # Chance-level diagonal for reference.
        plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves for 6 Classes')
        plt.legend(loc="lower right")
        plt.show()

    # Current time is used to label the saved artifacts.
    now = datetime.now()

    # exist_ok avoids the check-then-create race of the original code.
    Path("model_weights/").mkdir(exist_ok=True)
    torch.save(model.state_dict(), f"model_weights/model_{now.month:02}{now.day:02}{now.hour}_{now.minute:02}.txt")

    # Loss curves. The size/dpi are passed to subplots directly; the original
    # applied them to a separate, empty figure that was never used.
    fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(9, 10), dpi=80)

    # Use the number of recorded epochs so x and y always have equal length,
    # even when the training loop was skipped.
    epoch_axis = range(1, 1 + len(mean_losses_train))
    ax1.plot(epoch_axis, [x.detach().cpu() for x in mean_losses_train], label="Train", color="blue")
    ax2.plot(epoch_axis, [x.detach().cpu() for x in mean_losses_test], label="Test", color="red")
    fig.legend()

    Path("artifacts/").mkdir(exist_ok=True)
    fig.savefig(Path("artifacts") / f"session_{now.month:02}{now.day:02}{now.hour}_{now.minute:02}.png")

    # Final evaluation on the full test set, one sample at a time.
    true_labels = test_dataset.get_labels()

    model.eval()

    predicted_labels = []
    with torch.no_grad():
        for inputs, _ in test_dataset:
            # Add a leading batch dimension before the forward pass.
            inputs = inputs.unsqueeze(0).to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            predicted_labels.extend(predicted.cpu().numpy())

    conf_matrix = confusion_matrix(true_labels, predicted_labels)

    print("Confusion Matrix:")
    print(conf_matrix)

    create_confusion_matrix(true_labels, predicted_labels)

    class_report = classification_report(true_labels, predicted_labels)
    print("\nClassification Report:")
    print(class_report)
|
|
|
|
|
def str_to_bool(value: str) -> bool:
    """Convert a command-line string into a boolean.

    ``type=bool`` in argparse treats every non-empty string (including
    "False") as ``True``; this converter interprets the text instead.

    Args:
        value: Raw argument text, e.g. "true", "False", "1", "no".

    Returns:
        The boolean the text stands for. The empty string maps to ``False``,
        matching what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--nb_epochs", help="number of training iterations", default=1, type=int)
    parser.add_argument("--batch_size", help="batch_size", default=25, type=int)
    parser.add_argument(
        "--balanced_batches",
        help="whether to balance batches for class labels",
        default=True,
        # Explicit parser so that "--balanced_batches False" really disables it.
        type=str_to_bool,
    )
    args = parser.parse_args()

    main(args)