File size: 2,211 Bytes
ab687e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from ..datasets.simmim_modis_dataset import MODISDataset
from ..transforms import SimmimTransform
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
# Registry mapping dataset-name strings (the value of config.DATA.DATASET)
# to dataset classes; resolved at build time via get_dataset_from_dict().
DATASETS = {
    'MODIS': MODISDataset,
}
def collate_fn(batch):
    """Collate samples whose first element may be a tuple of sub-items.

    Ordinary samples are handed straight to ``default_collate``. When each
    sample's first element is itself a tuple, every position of that tuple
    is collated independently across the batch (positions holding ``None``
    stay ``None``), and the samples' second elements are collated last.
    """
    first_sample = batch[0]
    if not isinstance(first_sample[0], tuple):
        return default_collate(batch)

    collated = []
    # Collate each sub-item position across the whole batch.
    for idx, item in enumerate(first_sample[0]):
        if item is None:
            collated.append(None)
        else:
            collated.append(
                default_collate([sample[0][idx] for sample in batch]))
    # Finally collate the second element (e.g. the target) of every sample.
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated
def get_dataset_from_dict(dataset_name):
    """Look up a dataset class by name in the ``DATASETS`` registry.

    Args:
        dataset_name: Key identifying the dataset (e.g. ``'MODIS'``).

    Returns:
        The dataset class registered under ``dataset_name``.

    Raises:
        KeyError: If ``dataset_name`` is not registered; the message lists
            the available dataset names.
    """
    try:
        return DATASETS[dataset_name]
    except KeyError:
        # Join the keys so the message reads "MODIS, ..." instead of the
        # raw dict_keys([...]) repr the original f-string produced.
        available = ', '.join(sorted(DATASETS))
        # `from None` suppresses the redundant chained KeyError traceback.
        raise KeyError(
            f"{dataset_name} is not an existing dataset."
            f" Available datasets: {available}") from None
def build_mim_dataloader(config, logger):
    """Build the distributed training dataloader for masked-image pre-training.

    Args:
        config: Experiment config; reads ``DATA.DATASET``, ``DATA.DATA_PATHS``,
            ``DATA.IMG_SIZE``, ``DATA.BATCH_SIZE`` and ``DATA.NUM_WORKERS``.
        logger: Logger used to report the transform and the dataset size.

    Returns:
        A ``DataLoader`` over the training split, sharded across ranks with a
        shuffling ``DistributedSampler`` and using the tuple-aware
        ``collate_fn`` above.
    """
    transform = SimmimTransform(config)
    logger.info(f'Pre-train data transform:\n{transform}')

    # Resolve the configured dataset class and build the training split.
    dataset_cls = get_dataset_from_dict(config.DATA.DATASET)
    dataset = dataset_cls(config,
                          config.DATA.DATA_PATHS,
                          split="train",
                          img_size=config.DATA.IMG_SIZE,
                          transform=transform)
    logger.info(f'Build dataset: train images = {len(dataset)}')

    # One shard per rank; shuffle within each epoch.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(),
                                 shuffle=True)
    return DataLoader(dataset,
                      config.DATA.BATCH_SIZE,
                      sampler=sampler,
                      num_workers=config.DATA.NUM_WORKERS,
                      pin_memory=True,
                      drop_last=True,
                      collate_fn=collate_fn)
|