FormFighterAIStack / lib /data /dataloader.py
Techt3o's picture
e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
f561f8b verified
raw
history blame
1.29 kB
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
from .datasets import EvalDataset, DataFactory
from ..utils.data_utils import make_collate_fn
def setup_eval_dataloader(cfg, data, split='test', backbone=None):
    """Create an evaluation DataLoader over ``EvalDataset``.

    Args:
        cfg: Experiment config. ``cfg.MODEL.BACKBONE`` is read only when
            ``backbone`` is ``None``.
        data: Dataset selector forwarded unchanged to ``EvalDataset``
            (``setup_dloaders`` passes a name such as ``'3dpw'``).
        split: Split name forwarded to ``EvalDataset``; defaults to ``'test'``.
        backbone: Optional backbone override; falls back to the config value.

    Returns:
        A ``torch.utils.data.DataLoader`` with batch size 1, no shuffling,
        no worker processes, pinned memory, and ``make_collate_fn()`` as the
        collate function.
    """
    chosen_backbone = cfg.MODEL.BACKBONE if backbone is None else backbone
    eval_set = EvalDataset(cfg, data, split, chosen_backbone)

    # Deterministic, one-sample-per-batch iteration in the main process
    # (num_workers=0 means no worker subprocesses in PyTorch's DataLoader).
    loader_kwargs = dict(
        batch_size=1,
        num_workers=0,
        shuffle=False,
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )
    return torch.utils.data.DataLoader(eval_set, **loader_kwargs)
def setup_train_dataloader(cfg):
    """Create the shuffled training DataLoader over ``DataFactory``.

    Args:
        cfg: Experiment config. Reads ``cfg.DEBUG``, ``cfg.NUM_WORKERS``
            (only when not in debug mode), ``cfg.TRAIN.STAGE`` (forwarded to
            ``DataFactory``) and ``cfg.TRAIN.BATCH_SIZE``.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles, pins memory, and
        collates batches with ``make_collate_fn()``.
    """
    if cfg.DEBUG:
        # num_workers=0 loads data in the main process, which is easier to
        # step through while debugging.
        worker_count = 0
    else:
        worker_count = cfg.NUM_WORKERS

    training_set = DataFactory(cfg, cfg.TRAIN.STAGE)
    return torch.utils.data.DataLoader(
        training_set,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=worker_count,
        shuffle=True,
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )
def setup_dloaders(cfg, dset='3dpw', split='val'):
    """Build the training loader and an evaluation loader together.

    The evaluation loader is constructed first, using the config's backbone,
    followed by the training loader.

    Args:
        cfg: Experiment config shared by both loaders.
        dset: Dataset selector for evaluation; defaults to ``'3dpw'``.
        split: Evaluation split name; defaults to ``'val'``.

    Returns:
        Tuple ``(train_dloader, test_dloader)``.
    """
    eval_loader = setup_eval_dataloader(cfg, dset, split, cfg.MODEL.BACKBONE)
    return setup_train_dataloader(cfg), eval_loader