# Author: Caleb Spradlin
# Commit: ab687e7 (initial commit)
# (Web-viewer residue — "raw / history / blame / 11.1 kB" — removed so the
# file is valid Python.)
from argparse import ArgumentParser, Namespace
import multiprocessing
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from pytorch_caney.datasets.modis_dataset import MODISDataset
from pytorch_caney.utils import check_gpus_available
class UNet(nn.Module):
    """U-Net encoder/decoder for semantic segmentation.

    Architecture based on "U-Net: Convolutional Networks for Biomedical
    Image Segmentation" — https://arxiv.org/abs/1505.04597

    Args:
        num_channels: channels of the input image.
        num_classes: channels of the output logit map.
        num_layers: number of down (and up) stages.
        features_start: channel width after the initial DoubleConv.
        bilinear: use bilinear upsampling instead of transposed convs.
    """

    def __init__(
        self,
        num_channels: int = 7,
        num_classes: int = 19,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False
    ):
        super().__init__()
        self.num_layers = num_layers

        modules = [DoubleConv(num_channels, features_start)]
        width = features_start
        # Contracting path: each stage halves resolution, doubles channels.
        for _ in range(num_layers - 1):
            modules.append(Down(width, width * 2))
            width *= 2
        # Expanding path: each stage doubles resolution, halves channels.
        for _ in range(num_layers - 1):
            modules.append(Up(width, width // 2, bilinear))
            width //= 2
        # Final 1x1 conv maps features to per-class logits.
        modules.append(nn.Conv2d(width, num_classes, kernel_size=1))
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        skips = [self.layers[0](x)]
        # Encoder: feed each Down the previous stage's output, keep skips.
        for down in self.layers[1: self.num_layers]:
            skips.append(down(skips[-1]))
        # Decoder: each Up fuses the running output with its skip tensor.
        for idx, up in enumerate(self.layers[self.num_layers: -1]):
            skips[-1] = up(skips[-1], skips[-2 - idx])
        return self.layers[-1](skips[-1])
class DoubleConv(nn.Module):
    """Two successive (3x3 Conv -> BatchNorm -> ReLU) stages.

    The first stage maps in_ch -> out_ch; the second keeps out_ch.
    Padding of 1 preserves spatial dimensions.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            stages.extend([
                nn.Conv2d(src, dst, kernel_size=3, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class Down(nn.Module):
    """Downscaling stage: 2x2 max-pool (halves H and W) then DoubleConv."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.net = nn.Sequential(pool, DoubleConv(in_ch, out_ch))

    def forward(self, x):
        return self.net(x)
class Up(nn.Module):
    """Upsampling stage of the expanding path.

    Upsamples by 2x (bilinear interpolation + 1x1 conv, or a learned 2x2
    transposed convolution — both halve the channel count), pads to match
    the skip feature map, concatenates along channels, then applies a
    DoubleConv.
    """

    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        if bilinear:
            self.upsample = nn.Sequential(
                nn.Upsample(
                    scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
            )
        else:
            self.upsample = nn.ConvTranspose2d(
                in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.upsample(x1)
        # Zero-pad x1 so its spatial size matches the skip tensor x2.
        dh = x2.shape[2] - x1.shape[2]
        dw = x2.shape[3] - x1.shape[3]
        pad = [
            dw // 2, dw - dw // 2,
            dh // 2, dh - dh // 2
        ]
        x1 = F.pad(x1, pad)
        # Concatenate skip connection and upsampled map channel-wise.
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
class SegmentationModel(LightningModule):
    """Semantic-segmentation LightningModule built around a UNet.

    Constructs the network, the per-channel normalization transform, and
    the train/validation MODISDataset splits at init time.

    Args:
        data_path: list of dataset root paths (None is treated as []).
        n_classes: number of output classes for the UNet head.
        batch_size: mini-batch size for both dataloaders.
        lr: Adam learning rate.
        num_layers: number of UNet down/up stages.
        features_start: channel width of the first UNet stage.
        bilinear: use bilinear upsampling instead of transposed convs.
        **kwargs: forwarded to LightningModule.
    """

    def __init__(
        self,
        data_path: list = None,
        n_classes: int = 18,
        batch_size: int = 256,
        lr: float = 3e-4,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # BUGFIX: original signature used a mutable default argument
        # (data_path: list = []), shared across all instances; a None
        # sentinel keeps the same behavior without the shared state.
        self.data_paths = [] if data_path is None else data_path
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.learning_rate = lr
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear
        # Per-step validation losses, reduced and cleared each epoch.
        self.validation_step_outputs = []
        self.net = UNet(
            num_classes=self.n_classes,
            num_layers=self.num_layers,
            features_start=self.features_start,
            bilinear=self.bilinear
        )
        # Per-channel stats for 7-band input; presumably computed over the
        # MODIS training split — TODO confirm.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.0173, 0.0332, 0.0088,
                          0.0136, 0.0381, 0.0348, 0.0249],
                    std=[0.0150, 0.0127, 0.0124,
                         0.0128, 0.0120, 0.0159, 0.0164]
                ),
            ]
        )
        print('> Init datasets')
        self.trainset = MODISDataset(
            self.data_paths, split="train", transform=self.transform)
        self.validset = MODISDataset(
            self.data_paths, split="valid", transform=self.transform)
        print('Done init datasets')

    def forward(self, x):
        """Run the underlying UNet on a batch of images."""
        return self.net(x)

    def training_step(self, batch, batch_nb):
        """Compute and log cross-entropy loss for one training batch."""
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        # ignore_index=250 skips the fill/no-data label in the loss.
        loss = F.cross_entropy(out, mask, ignore_index=250)
        log_dict = {"train_loss": loss}
        self.log_dict(log_dict)
        return {"loss": loss, "log": log_dict, "progress_bar": log_dict}

    def validation_step(self, batch, batch_idx):
        """Accumulate per-batch validation loss for the epoch reduction."""
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        self.validation_step_outputs.append(loss_val)
        return {"val_loss": loss_val}

    def on_validation_epoch_end(self):
        """Average accumulated validation losses, log, and reset."""
        loss_val = torch.stack(self.validation_step_outputs).mean()
        log_dict = {"val_loss": loss_val}
        # sync_dist aggregates the metric across DDP processes.
        self.log("val_loss", loss_val, sync_dist=True)
        self.validation_step_outputs.clear()
        return {
            "log": log_dict,
            "val_loss": log_dict["val_loss"],
            "progress_bar": log_dict
        }

    def configure_optimizers(self):
        """Plain Adam; LR scheduler is intentionally disabled for now."""
        opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt]  # , [sch]

    def train_dataloader(self):
        """Shuffled training loader using all available CPU workers."""
        return DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            num_workers=multiprocessing.cpu_count(),
            shuffle=True
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader."""
        return DataLoader(
            self.validset,
            batch_size=self.batch_size,
            num_workers=multiprocessing.cpu_count(),
            shuffle=False
        )
def main(hparams: Namespace):
    """Build the segmentation model from CLI hyperparameters and train it.

    Expects ``hparams`` to carry ``ngpus`` plus every keyword accepted by
    ``SegmentationModel``; ``ngpus`` is consumed here and removed before
    the remaining namespace is forwarded to the model constructor.
    """
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    ngpus = int(hparams.ngpus)
    # PT Lightning's module does not expect this attribute; delete it
    # before splatting the namespace into the constructor.
    del hparams.ngpus
    model = SegmentationModel(**vars(hparams))
    # ------------------------
    # 2 SET CALLBACKS
    # ------------------------
    train_callbacks = [
        # BUGFIX: ModelCheckpoint appends ".ckpt" automatically, so the
        # original pattern '{epoch}-{val_loss:.2f}.ckpt' produced files
        # ending in ".ckpt.ckpt"; drop the explicit extension.
        ModelCheckpoint(dirpath='models/',
                        monitor='val_loss',
                        save_top_k=5,
                        filename='{epoch}-{val_loss:.2f}'),
        EarlyStopping("val_loss", patience=10, mode='min'),
    ]
    # Fail fast if the requested GPU count is not available.
    check_gpus_available(ngpus)
    # ------------------------
    # 3 INIT TRAINER
    # ------------------------
    trainer = Trainer(
        accelerator="gpu",
        devices=ngpus,
        strategy="ddp",
        min_epochs=1,
        max_epochs=500,
        callbacks=train_callbacks,
        logger=CSVLogger(save_dir="logs/"),
        # precision=16 # makes loss nan, need to fix that
    )
    # ------------------------
    # 4 START TRAINING
    # ------------------------
    trainer.fit(model)
    trainer.save_checkpoint("best_model.ckpt")
if __name__ == "__main__":
    cli_lightning_logo()
    parser = ArgumentParser()
    parser.add_argument(
        "--data_path", nargs='+', required=True,
        help="path where dataset is stored")
    parser.add_argument('--ngpus', type=int,
                        default=torch.cuda.device_count(),
                        help='number of gpus to use')
    parser.add_argument(
        "--n-classes", type=int, default=18, help="number of classes")
    parser.add_argument(
        "--batch_size", type=int, default=256, help="size of the batches")
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="adam: learning rate")
    parser.add_argument(
        "--num_layers", type=int, default=5, help="number of layers on u-net")
    # BUGFIX: was type=float; the value is used as a Conv2d channel count,
    # which must be an integer (a float would make UNet construction fail).
    parser.add_argument(
        "--features_start", type=int, default=64,
        help="number of features in first layer")
    parser.add_argument(
        "--bilinear", action="store_true", default=False,
        help="whether to use bilinear interpolation or transposed")
    hparams = parser.parse_args()
    main(hparams)