from pl_bolts.models.vision.unet import UNet
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
import torch
from torch.nn import functional as F
from torchmetrics import MetricCollection, Accuracy, IoU
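# NOTE: these imports target the older torchmetrics / pytorch_lightning APIs this
# file was written against; later torchmetrics releases rename IoU to JaccardIndex.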


# -------------------------------------------------------------------------------
# class UNetSegmentation
# Performs training and pixel-wise classification (semantic segmentation) of
# satellite imagery using a UNet CNN.
# -------------------------------------------------------------------------------
@MODEL_REGISTRY
class UNetSegmentation(LightningModule):
    # ---------------------------------------------------------------------------
    # __init__
    # ---------------------------------------------------------------------------
    def __init__(
        self,
        input_channels: int = 4,
        num_classes: int = 19,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear

        self.net = UNet(
            input_channels=self.input_channels,
            num_classes=num_classes,
            num_layers=self.num_layers,
            features_start=self.features_start,
            bilinear=self.bilinear,
        )
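        # Accuracy and IoU are tracked through a shared MetricCollection;
        # clone(prefix=...) below keeps separate metric state and log names
        # for training and validation.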
        metrics = MetricCollection(
            [
                Accuracy(),
                IoU(num_classes=self.num_classes),
            ]
        )
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')

    # ---------------------------------------------------------------------------
    # model methods
    # ---------------------------------------------------------------------------
    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        img, mask = img.float(), mask.long()

        # Forward step, calculate logits and loss
        logits = self(img)
        # loss_val = F.cross_entropy(logits, mask)

        # Get target tensor from logits for metrics, calculate metrics
        probs = torch.nn.functional.softmax(logits, dim=1)
        probs = torch.argmax(probs, dim=1)
        # metrics_train = self.train_metrics(probs, mask)

        # log_dict = {"train_loss": loss_val.detach()}
        # return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict}
        # return {
        #     "loss": loss_val, "train_acc": metrics_train['train_Accuracy'],
        #     "train_iou": metrics_train['train_IoU']
        # }

        tensorboard_logs = self.train_metrics(probs, mask)
        tensorboard_logs['loss'] = F.cross_entropy(logits, mask)
        # tensorboard_logs['lr'] = self._get_current_lr()

        self.log(
            'acc', tensorboard_logs['train_Accuracy'],
            sync_dist=True, prog_bar=True
        )
        self.log(
            'iou', tensorboard_logs['train_IoU'],
            sync_dist=True, prog_bar=True
        )

        return tensorboard_logs

    def training_epoch_end(self, outputs):
        pass
        # Get average metrics from multi-GPU batch sources
        # loss_val = torch.stack([x["loss"] for x in outputs]).mean()
        # acc_train = torch.stack([x["train_acc"] for x in outputs]).mean()
        # iou_train = torch.stack([x["train_iou"] for x in outputs]).mean()
        # tensorboard_logs = self.train_metrics(probs, mask)
        # tensorboard_logs['loss'] = F.cross_entropy(logits, mask)
        # tensorboard_logs['lr'] = self._get_current_lr()
        # self.log(
        #     'acc', tensorboard_logs['train_Accuracy'],
        #     sync_dist=True, prog_bar=True
        # )
        # self.log(
        #     'iou', tensorboard_logs['train_IoU'],
        #     sync_dist=True, prog_bar=True
        # )
        # # Send output to logger
        # self.log(
        #     "loss", loss_val, on_epoch=True, prog_bar=True, logger=True)
        # self.log(
        #     "train_acc", acc_train,
        #     on_epoch=True, prog_bar=True, logger=True)
        # self.log(
        #     "train_iou", iou_train,
        #     on_epoch=True, prog_bar=True, logger=True)
        # return tensorboard_logs

    def validation_step(self, batch, batch_idx):
        # Get data, change type for validation
        img, mask = batch
        img, mask = img.float(), mask.long()

        # Forward step, calculate logits and loss
        logits = self(img)
        # loss_val = F.cross_entropy(logits, mask)

        # Get target tensor from logits for metrics, calculate metrics
        probs = torch.nn.functional.softmax(logits, dim=1)
        probs = torch.argmax(probs, dim=1)
        # metrics_val = self.val_metrics(probs, mask)
        # return {
        #     "val_loss": loss_val, "val_acc": metrics_val['val_Accuracy'],
        #     "val_iou": metrics_val['val_IoU']
        # }

        tensorboard_logs = self.val_metrics(probs, mask)
        tensorboard_logs['val_loss'] = F.cross_entropy(logits, mask)

        self.log(
            'val_loss', tensorboard_logs['val_loss'],
            sync_dist=True, prog_bar=True
        )
        self.log(
            'val_acc', tensorboard_logs['val_Accuracy'],
            sync_dist=True, prog_bar=True
        )
        self.log(
            'val_iou', tensorboard_logs['val_IoU'],
            sync_dist=True, prog_bar=True
        )

        return tensorboard_logs

    # def validation_epoch_end(self, outputs):
    #     # Get average metrics from multi-GPU batch sources
    #     loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
    #     acc_val = torch.stack([x["val_acc"] for x in outputs]).mean()
    #     iou_val = torch.stack([x["val_iou"] for x in outputs]).mean()
    #     # Send output to logger
    #     self.log(
    #         "val_loss", torch.mean(self.all_gather(loss_val)),
    #         on_epoch=True, prog_bar=True, logger=True)
    #     self.log(
    #         "val_acc", torch.mean(self.all_gather(acc_val)),
    #         on_epoch=True, prog_bar=True, logger=True)
    #     self.log(
    #         "val_iou", torch.mean(self.all_gather(iou_val)),
    #         on_epoch=True, prog_bar=True, logger=True)

    # def configure_optimizers(self):
    #     opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
    #     sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
    #     return [opt], [sch]
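
    # NOTE: with configure_optimizers() left commented out, the LightningModule
    # default is used; the commented version above also assumes a self.lr
    # attribute that __init__ does not currently define.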

    # Inference steps: forward the incoming batch through the network and
    # return the raw logits.
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)
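

# -------------------------------------------------------------------------------
# Usage sketch (illustrative only): build the module and run a single forward
# pass on random data. The (2, 4, 256, 256) input shape is an assumption for
# demonstration, matching the default 4-band input and 19-class output.
# -------------------------------------------------------------------------------
if __name__ == "__main__":
    model = UNetSegmentation()
    dummy_batch = torch.randn(2, 4, 256, 256)   # (batch, bands, height, width)
    logits = model(dummy_batch)                 # (batch, num_classes, height, width)
    print(logits.shape)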