import os
import logging
from typing import Any, Union, Optional

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY

from terragpu.ai.deep_learning.datasets.segmentation_dataset \
    import SegmentationDataset


@DATAMODULE_REGISTRY
class SegmentationDataModule(LightningDataModule):
    """Lightning data module that generates, splits, and serves a
    SegmentationDataset for training, validation, and testing."""

    def __init__(
        self,

        # Dataset generation options
        dataset_dir: str = 'dataset/',
        images_regex: str = 'dataset/images/*.tif',
        labels_regex: str = 'dataset/labels/*.tif',
        generate_dataset: bool = True,
        tile_size: int = 256,
        max_patches: Union[float, int] = 100,
        augment: bool = True,
        # Container arguments default to None and are resolved in the
        # body to avoid sharing mutable default objects across calls.
        chunks: Optional[dict] = None,
        input_bands: Optional[list] = None,
        output_bands: Optional[list] = None,
        seed: int = 24,
        normalize: bool = True,
        pytorch: bool = True,

        # Split ratios and DataLoader options
        val_split: float = 0.2,
        test_split: float = 0.1,
        # None selects os.cpu_count() at runtime instead of freezing
        # the value at import time.
        num_workers: Optional[int] = None,
        batch_size: int = 32,
        shuffle: bool = True,
        pin_memory: bool = False,
        drop_last: bool = False,

        # Inference options
        raster_regex: str = 'rasters/*.tif',

        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        # Dataset generation options
        self.images_regex = images_regex
        self.labels_regex = labels_regex
        self.dataset_dir = dataset_dir
        self.generate_dataset = generate_dataset
        self.tile_size = tile_size
        self.max_patches = max_patches
        self.augment = augment
        # Resolve the None sentinels so callers still get the
        # documented defaults without sharing mutable state.
        self.chunks = chunks if chunks is not None \
            else {'band': 1, 'x': 2048, 'y': 2048}
        self.input_bands = input_bands if input_bands is not None \
            else ['CB', 'B', 'G', 'Y', 'R', 'RE', 'N1', 'N2']
        self.output_bands = output_bands if output_bands is not None \
            else ['B', 'G', 'R']
        self.seed = seed
        self.normalize = normalize
        self.pytorch = pytorch

        # Split ratios and inference options
        self.val_split = val_split
        self.test_split = test_split
        self.raster_regex = raster_regex

        # DataLoader options
        self.batch_size = batch_size
        # os.cpu_count() may return None, so fall back to one worker.
        self.num_workers = num_workers if num_workers is not None \
            else (os.cpu_count() or 1)
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last

    def prepare_data(self):
        """Generate the tiled dataset on disk if requested.

        The SegmentationDataset is instantiated purely for its side
        effects here; Lightning runs prepare_data in a single process,
        so the tiles are written once rather than by every worker.
        """
        if self.generate_dataset:
            SegmentationDataset(
                images_regex=self.images_regex,
                labels_regex=self.labels_regex,
                dataset_dir=self.dataset_dir,
                generate_dataset=self.generate_dataset,
                tile_size=self.tile_size,
                max_patches=self.max_patches,
                augment=self.augment,
                chunks=self.chunks,
                input_bands=self.input_bands,
                output_bands=self.output_bands,
                seed=self.seed,
                normalize=self.normalize,
                pytorch=self.pytorch,
            )

    def setup(self, stage: Optional[str] = None):
        """Load the generated dataset and split it into subsets."""
        # Reload from disk (generate_dataset=False) rather than
        # regenerating the tiles.
        segmentation_dataset = SegmentationDataset(
            images_regex=self.images_regex,
            labels_regex=self.labels_regex,
            dataset_dir=self.dataset_dir,
            generate_dataset=False,
            tile_size=self.tile_size,
            max_patches=self.max_patches,
            augment=self.augment,
            chunks=self.chunks,
            input_bands=self.input_bands,
            output_bands=self.output_bands,
            seed=self.seed,
            normalize=self.normalize,
            pytorch=self.pytorch,
        )

        # Compute split sizes; the training split absorbs any rounding
        # remainder so the three lengths sum to the dataset total.
        val_len = round(self.val_split * len(segmentation_dataset))
        test_len = round(self.test_split * len(segmentation_dataset))
        train_len = len(segmentation_dataset) - val_len - test_len

        # Seed the generator so the split is reproducible across runs.
        self.train_set, self.val_set, self.test_set = random_split(
            segmentation_dataset, lengths=[train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(self.seed)
        )
        logging.info("Initialized datasets...")

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> DataLoader:
        # Evaluation loaders never shuffle, so metrics are comparable
        # across epochs.
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )

    def predict_dataloader(self) -> DataLoader:
        # Inference over the rasters matched by raster_regex is not
        # implemented yet.
        raise NotImplementedError
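
# A minimal usage sketch, assuming tiles already exist under the default
# 'dataset/' layout; the batch size and the batch inspection below are
# illustrative assumptions, not part of this module's API.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    datamodule = SegmentationDataModule(
        dataset_dir='dataset/',
        images_regex='dataset/images/*.tif',
        labels_regex='dataset/labels/*.tif',
        generate_dataset=False,  # assume the tiles were generated earlier
        batch_size=16,
    )
    datamodule.setup()
    # Pull one training batch to confirm the loaders are wired correctly.
    batch = next(iter(datamodule.train_dataloader()))
    logging.info("Loaded one training batch with %d element(s)", len(batch))
    # With a LightningModule, training would follow the usual pattern:
    #   pytorch_lightning.Trainer(max_epochs=10).fit(model, datamodule=datamodule)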
|