File size: 1,287 Bytes
f561f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch

from .datasets import EvalDataset, DataFactory
from ..utils.data_utils import make_collate_fn


def setup_eval_dataloader(cfg, data, split='test', backbone=None):
    """Build a deterministic, one-sample-per-batch DataLoader for evaluation.

    Args:
        cfg: Experiment config node. ``cfg.MODEL.BACKBONE`` is read only
            when ``backbone`` is not supplied.
        data: Dataset identifier forwarded to ``EvalDataset`` (presumably a
            dataset name such as ``'3dpw'`` -- confirm against callers).
        split: Split name forwarded to ``EvalDataset``; defaults to ``'test'``.
        backbone: Backbone name forwarded to ``EvalDataset``. When ``None``,
            ``cfg.MODEL.BACKBONE`` is used instead.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``EvalDataset`` with
        batch size 1, loading in the main process, no shuffling, and
        pinned host memory.
    """
    chosen_backbone = cfg.MODEL.BACKBONE if backbone is None else backbone
    eval_set = EvalDataset(cfg, data, split, chosen_backbone)

    loader_kwargs = dict(
        batch_size=1,
        num_workers=0,   # load samples in the main process
        shuffle=False,   # keep evaluation order fixed across runs
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )
    return torch.utils.data.DataLoader(eval_set, **loader_kwargs)


def setup_train_dataloader(cfg):
    """Build the shuffled training DataLoader for the configured stage.

    Args:
        cfg: Experiment config node. Reads ``cfg.DEBUG``,
            ``cfg.NUM_WORKERS`` (only when not in debug mode),
            ``cfg.TRAIN.STAGE`` and ``cfg.TRAIN.BATCH_SIZE``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset returned by
        ``DataFactory(cfg, cfg.TRAIN.STAGE)``. It uses batch size
        ``cfg.TRAIN.BATCH_SIZE``, shuffles, and pins host memory.
    """
    if cfg.DEBUG:
        # Debug mode: no worker subprocesses, all loading in the main process.
        worker_count = 0
    else:
        worker_count = cfg.NUM_WORKERS

    stage_dataset = DataFactory(cfg, cfg.TRAIN.STAGE)
    return torch.utils.data.DataLoader(
        stage_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=worker_count,
        shuffle=True,
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )


def setup_dloaders(cfg, dset='3dpw', split='val'):
    """Create the training and evaluation DataLoaders.

    Args:
        cfg: Experiment config node shared by both loaders.
        dset: Dataset identifier for the evaluation loader; defaults to ``'3dpw'``.
        split: Split used by the evaluation loader; defaults to ``'val'``.

    Returns:
        A tuple ``(train_dloader, test_dloader)``.
    """
    # The evaluation loader is constructed before the training loader.
    eval_loader = setup_eval_dataloader(
        cfg, dset, split, backbone=cfg.MODEL.BACKBONE
    )
    return setup_train_dataloader(cfg), eval_loader