from ..datasets.simmim_modis_dataset import MODISDataset
from ..transforms import SimmimTransform

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate


# Registry mapping dataset names (as given in config.DATA.DATASET)
# to their dataset classes.
DATASETS = {
    'MODIS': MODISDataset,
}


def collate_fn(batch):
    """Collate samples whose first element may be a tuple of tensors.

    Each sample is expected to be (data, target). When `data` is itself a
    tuple (e.g. (image, mask) pairs from SimMIM-style datasets), each
    position is collated separately and `None` entries are passed through
    unchanged; otherwise the default collation is used.
    """
    if not isinstance(batch[0][0], tuple):
        return default_collate(batch)

    batch_num = len(batch)
    ret = []
    for item_idx in range(len(batch[0][0])):
        if batch[0][0][item_idx] is None:
            ret.append(None)
        else:
            ret.append(default_collate(
                [batch[i][0][item_idx] for i in range(batch_num)]))
    # Collate the targets last so the output mirrors the input layout.
    ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
    return ret


def get_dataset_from_dict(dataset_name):
    """Look up a dataset class by name, with a helpful error on a miss."""
    try:
        dataset_to_use = DATASETS[dataset_name]
    except KeyError:
        error_msg = (f"{dataset_name} is not an existing dataset. "
                     f"Available datasets: {list(DATASETS.keys())}")
        raise KeyError(error_msg) from None
    return dataset_to_use


def build_mim_dataloader(config, logger):
    """Build the distributed training dataloader for MIM pre-training."""
    transform = SimmimTransform(config)
    logger.info(f'Pre-train data transform:\n{transform}')

    dataset_name = config.DATA.DATASET
    dataset_to_use = get_dataset_from_dict(dataset_name)
    dataset = dataset_to_use(config,
                             config.DATA.DATA_PATHS,
                             split="train",
                             img_size=config.DATA.IMG_SIZE,
                             transform=transform)
    logger.info(f'Build dataset: train images = {len(dataset)}')

    # Shard the dataset across ranks; requires an initialized process group.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(),
                                 shuffle=True)
    dataloader = DataLoader(dataset,
                            config.DATA.BATCH_SIZE,
                            sampler=sampler,
                            num_workers=config.DATA.NUM_WORKERS,
                            pin_memory=True,
                            drop_last=True,
                            collate_fn=collate_fn)

    return dataloader
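

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of how build_mim_dataloader might be driven once a
# torch.distributed process group exists. The config fields referenced
# above (DATA.DATASET, DATA.DATA_PATHS, DATA.IMG_SIZE, DATA.BATCH_SIZE,
# DATA.NUM_WORKERS) come from this file; how the project actually
# constructs `config` and `num_epochs` is an assumption here.
#
#   import logging
#   import torch.distributed as dist
#
#   dist.init_process_group(backend='nccl')  # e.g. launched via torchrun
#   logger = logging.getLogger(__name__)
#
#   dataloader = build_mim_dataloader(config, logger)
#   for epoch in range(num_epochs):
#       # DistributedSampler with shuffle=True only reshuffles if told
#       # the current epoch; otherwise every epoch sees the same order.
#       dataloader.sampler.set_epoch(epoch)
#       for batch in dataloader:
#           ...  # forward/backward pass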