|
from tqdm import tqdm |
|
import torch |
|
from net import Net |
|
from batch_sampler import BatchSampler |
|
from torch.nn import functional as F |
|
import numpy as np |
|
from net import Net |
|
from batch_sampler import BatchSampler |
|
from torch.nn import functional as F |
|
import numpy as np |
|
import torch.nn as nn |
|
|
|
from net import Net, ResNetModel, EfficientNetModel, EfficientNetModel_b7 |
|
from batch_sampler import BatchSampler |
|
from image_dataset import ImageDataset |
|
|
|
from typing import Callable, List, Tuple |
|
|
|
from sklearn.metrics import roc_curve, auc |
|
from sklearn.preprocessing import label_binarize |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
def train_model(
    model: Net,
    train_sampler: BatchSampler,
    optimizer: torch.optim.Optimizer,
    loss_function: Callable[..., torch.Tensor],
    device: str,
) -> List[torch.Tensor]:
    """Run one pass over ``train_sampler``, stepping the optimizer per batch.

    Args:
        model: Network to train; it is put into training mode before the loop.
        train_sampler: Iterable yielding ``(x, y)`` batches.
        optimizer: Optimizer whose ``zero_grad`` and ``step`` are called once
            per batch.
        loss_function: Called as ``loss_function(predictions, y)``; its result
            is back-propagated.
        device: Device that both tensors of every batch are moved to.

    Returns:
        One loss tensor per batch, in iteration order. The tensors are
        detached from the autograd graph, so they carry values only.
    """
    losses = []
    model.train()

    for x, y in tqdm(train_sampler):
        x, y = x.to(device), y.to(device)

        # Call the module itself rather than ``.forward`` so that any
        # registered forward hooks are executed.
        predictions = model(x)
        loss = loss_function(predictions, y)

        # Keep only a detached view for reporting: storing the live loss
        # would keep every batch's graph nodes referenced until the caller
        # drops the returned list.
        losses.append(loss.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses
|
|
|
|
|
def test_model(
    model: Net,
    test_sampler: BatchSampler,
    loss_function: Callable[..., torch.Tensor],
    device: str,
    fpr,
    tpr,
    roc,
    num_classes: int = 6,
) -> Tuple[List[torch.Tensor], np.ndarray]:
    """Evaluate ``model`` on ``test_sampler`` and record one-vs-rest ROC data.

    Args:
        model: Network to evaluate; it is put into eval mode.
        test_sampler: Iterable yielding ``(x, y)`` batches.
        loss_function: Called as ``loss_function(prediction, y)`` per batch.
        device: Device that both tensors of every batch are moved to.
        fpr: Container indexed by class id whose items support ``extend``;
            the false-positive rates of this run are appended to ``fpr[i]``.
        tpr: Same as ``fpr`` but for the true-positive rates.
        roc: Container indexed by class id; ``roc[i]`` is overwritten with
            the AUC of class ``i`` for this run.
        num_classes: Number of class indices ``0..num_classes-1`` to compute
            ROC curves for. Defaults to 6, the value previously hard-coded.
            NOTE(review): ``fpr``/``tpr``/``roc`` must have an entry for each
            of these indices — confirm against the caller.

    Returns:
        A tuple ``(losses, y_pred_probs)``: the per-batch loss tensors and
        the softmax outputs of all batches concatenated along axis 0.
    """
    model.eval()
    losses = []
    batch_probs = []
    batch_labels = []

    with torch.no_grad():
        for x, y in tqdm(test_sampler):
            x = x.to(device)
            y = y.to(device)

            # Call the module itself rather than ``.forward`` so that any
            # registered forward hooks are executed.
            prediction = model(x)
            losses.append(loss_function(prediction, y))

            probabilities = F.softmax(prediction, dim=1)
            batch_probs.append(probabilities.cpu().numpy())
            batch_labels.extend(y.cpu().numpy())

    y_pred_probs = np.concatenate(batch_probs, axis=0)
    y_true = np.array(batch_labels)

    for i in range(num_classes):
        # One-vs-rest curve: class ``i`` is the positive label.
        class_fpr, class_tpr, _ = roc_curve(y_true == i, y_pred_probs[:, i])
        fpr[i].extend(class_fpr)
        tpr[i].extend(class_tpr)
        # Compute the AUC from this run's curve only. ``fpr[i]``/``tpr[i]``
        # may already hold points from earlier calls, and ``auc`` rejects
        # x values that are not monotonic.
        roc[i] = auc(class_fpr, class_tpr)

    return losses, y_pred_probs