Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import torch | |
from .datasets import EvalDataset, DataFactory | |
from ..utils.data_utils import make_collate_fn | |
def setup_eval_dataloader(cfg, data, split='test', backbone=None):
    """Build the DataLoader used for evaluation.

    Args:
        cfg: Configuration object. ``cfg.MODEL.BACKBONE`` is read only when
            ``backbone`` is None.
        data: Forwarded unchanged to ``EvalDataset`` as its second argument
            (``setup_dloaders`` passes a dataset name such as ``'3dpw'``).
        split: Forwarded unchanged to ``EvalDataset``. Defaults to ``'test'``.
        backbone: Forwarded to ``EvalDataset``. When None, the value of
            ``cfg.MODEL.BACKBONE`` is used instead.

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``EvalDataset`` with
        batch size 1, no shuffling, no worker processes, pinned memory and
        the collate function returned by ``make_collate_fn()``.
    """
    # Only touch the config when the caller did not name a backbone.
    chosen_backbone = cfg.MODEL.BACKBONE if backbone is None else backbone

    eval_dataset = EvalDataset(cfg, data, split, chosen_backbone)

    # Evaluation settings are fixed: one item per batch, in dataset order.
    loader_options = {
        'batch_size': 1,
        'num_workers': 0,
        'shuffle': False,
        'pin_memory': True,
        'collate_fn': make_collate_fn(),
    }
    return torch.utils.data.DataLoader(eval_dataset, **loader_options)
def setup_train_dataloader(cfg):
    """Build the shuffled DataLoader used for training.

    Args:
        cfg: Configuration object. Reads ``cfg.DEBUG``, ``cfg.NUM_WORKERS``
            (only when ``cfg.DEBUG`` is falsy), ``cfg.TRAIN.STAGE`` and
            ``cfg.TRAIN.BATCH_SIZE``.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``DataFactory(cfg,
        cfg.TRAIN.STAGE)`` with shuffling enabled, pinned memory and the
        collate function returned by ``make_collate_fn()``.
    """
    # Debug runs use zero workers; otherwise the configured count applies.
    if cfg.DEBUG:
        worker_count = 0
    else:
        worker_count = cfg.NUM_WORKERS

    dataset = DataFactory(cfg, cfg.TRAIN.STAGE)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=worker_count,
        shuffle=True,
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )
def setup_dloaders(cfg, dset='3dpw', split='val'):
    """Build the training and evaluation DataLoaders together.

    Args:
        cfg: Configuration object handed to both loader factories.
            ``cfg.MODEL.BACKBONE`` is passed as the evaluation backbone.
        dset: Passed to ``setup_eval_dataloader`` as its ``data`` argument.
            Defaults to ``'3dpw'``.
        split: Passed to ``setup_eval_dataloader`` as its ``split``
            argument. Defaults to ``'val'``.

    Returns:
        A ``(train_dataloader, eval_dataloader)`` tuple, in that order.
    """
    # The evaluation loader is built first, then the training loader.
    eval_loader = setup_eval_dataloader(
        cfg, dset, split, backbone=cfg.MODEL.BACKBONE
    )
    train_loader = setup_train_dataloader(cfg)
    return (train_loader, eval_loader)