from pl_bolts.models.vision.unet import UNet | |
from pytorch_lightning import LightningModule | |
from pytorch_lightning.utilities.cli import MODEL_REGISTRY | |
import torch | |
from torch.nn import functional as F | |
from torchmetrics import MetricCollection, Accuracy, IoU | |
# ------------------------------------------------------------------------------- | |
# class UNet | |
# This class performs training and classification of satellite imagery using a | |
# UNet CNN. | |
# ------------------------------------------------------------------------------- | |
class UNetSegmentation(LightningModule): | |
# --------------------------------------------------------------------------- | |
# __init__ | |
# --------------------------------------------------------------------------- | |
def __init__( | |
self, | |
input_channels: int = 4, | |
num_classes: int = 19, | |
num_layers: int = 5, | |
features_start: int = 64, | |
bilinear: bool = False, | |
): | |
super().__init__() | |
self.input_channels = input_channels | |
self.num_classes = num_classes | |
self.num_layers = num_layers | |
self.features_start = features_start | |
self.bilinear = bilinear | |
self.net = UNet( | |
input_channels=self.input_channels, | |
num_classes=num_classes, | |
num_layers=self.num_layers, | |
features_start=self.features_start, | |
bilinear=self.bilinear, | |
) | |
metrics = MetricCollection( | |
[ | |
Accuracy(), IoU(num_classes=self.num_classes) | |
] | |
) | |
self.train_metrics = metrics.clone(prefix='train_') | |
self.val_metrics = metrics.clone(prefix='val_') | |
# --------------------------------------------------------------------------- | |
# model methods | |
# --------------------------------------------------------------------------- | |
def forward(self, x): | |
return self.net(x) | |
def training_step(self, batch, batch_nb): | |
img, mask = batch | |
img, mask = img.float(), mask.long() | |
# Forward step, calculate logits and loss | |
logits = self(img) | |
# loss_val = F.cross_entropy(logits, mask) | |
# Get target tensor from logits for metrics, calculate metrics | |
probs = torch.nn.functional.softmax(logits, dim=1) | |
probs = torch.argmax(probs, dim=1) | |
# metrics_train = self.train_metrics(probs, mask) | |
# log_dict = {"train_loss": loss_val.detach()} | |
# return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict} | |
# return { | |
# "loss": loss_val, "train_acc": metrics_train['train_Accuracy'], | |
# "train_iou": metrics_train['train_IoU'] | |
# } | |
tensorboard_logs = self.train_metrics(probs, mask) | |
tensorboard_logs['loss'] = F.cross_entropy(logits, mask) | |
# tensorboard_logs['lr'] = self._get_current_lr() | |
self.log( | |
'acc', tensorboard_logs['train_Accuracy'], | |
sync_dist=True, prog_bar=True | |
) | |
self.log( | |
'iou', tensorboard_logs['train_IoU'], | |
sync_dist=True, prog_bar=True | |
) | |
return tensorboard_logs | |
def training_epoch_end(self, outputs): | |
pass | |
# Get average metrics from multi-GPU batch sources | |
# loss_val = torch.stack([x["loss"] for x in outputs]).mean() | |
# acc_train = torch.stack([x["train_acc"] for x in outputs]).mean() | |
# iou_train = torch.stack([x["train_iou"] for x in outputs]).mean() | |
# tensorboard_logs = self.train_metrics(probs, mask) | |
# tensorboard_logs['loss'] = F.cross_entropy(logits, mask) | |
# tensorboard_logs['lr'] = self._get_current_lr() | |
# self.log( | |
# 'acc', tensorboard_logs['train_Accuracy'], | |
# sync_dist=True, prog_bar=True | |
# ) | |
# self.log( | |
# 'iou', tensorboard_logs['train_IoU'], | |
# sync_dist=True, prog_bar=True | |
# ) | |
# # Send output to logger | |
# self.log( | |
# "loss", loss_val, on_epoch=True, prog_bar=True, logger=True) | |
# self.log( | |
# "train_acc", acc_train, | |
# on_epoch=True, prog_bar=True, logger=True) | |
# self.log( | |
# "train_iou", iou_train, | |
# on_epoch=True, prog_bar=True, logger=True) | |
# return tensorboard_logs | |
def validation_step(self, batch, batch_idx): | |
# Get data, change type for validation | |
img, mask = batch | |
img, mask = img.float(), mask.long() | |
# Forward step, calculate logits and loss | |
logits = self(img) | |
# loss_val = F.cross_entropy(logits, mask) | |
# Get target tensor from logits for metrics, calculate metrics | |
probs = torch.nn.functional.softmax(logits, dim=1) | |
probs = torch.argmax(probs, dim=1) | |
# metrics_val = self.val_metrics(probs, mask) | |
# return { | |
# "val_loss": loss_val, "val_acc": metrics_val['val_Accuracy'], | |
# "val_iou": metrics_val['val_IoU'] | |
# } | |
tensorboard_logs = self.val_metrics(probs, mask) | |
tensorboard_logs['val_loss'] = F.cross_entropy(logits, mask) | |
self.log( | |
'val_loss', tensorboard_logs['val_loss'], | |
sync_dist=True, prog_bar=True | |
) | |
self.log( | |
'val_acc', tensorboard_logs['val_Accuracy'], | |
sync_dist=True, prog_bar=True | |
) | |
self.log( | |
'val_iou', tensorboard_logs['val_IoU'], | |
sync_dist=True, prog_bar=True | |
) | |
return tensorboard_logs | |
# def validation_epoch_end(self, outputs): | |
# # Get average metrics from multi-GPU batch sources | |
# loss_val = torch.stack([x["val_loss"] for x in outputs]).mean() | |
# acc_val = torch.stack([x["val_acc"] for x in outputs]).mean() | |
# iou_val = torch.stack([x["val_iou"] for x in outputs]).mean() | |
# # Send output to logger | |
# self.log( | |
# "val_loss", torch.mean(self.all_gather(loss_val)), | |
# on_epoch=True, prog_bar=True, logger=True) | |
# self.log( | |
# "val_acc", torch.mean(self.all_gather(acc_val)), | |
# on_epoch=True, prog_bar=True, logger=True) | |
# self.log( | |
# "val_iou", torch.mean(self.all_gather(iou_val)), | |
# on_epoch=True, prog_bar=True, logger=True) | |
# def configure_optimizers(self): | |
# opt = torch.optim.Adam(self.net.parameters(), lr=self.lr) | |
# sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10) | |
# return [opt], [sch] | |
def test_step(self, batch, batch_idx, dataloader_idx=0): | |
return self(batch) | |
def predict_step(self, batch, batch_idx, dataloader_idx=0): | |
return self(batch) | |