File size: 2,320 Bytes
97a6728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.fdf import FDF256Dataset
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from .utils import final_eval_fn, train_eval_fn


# Root directories can be overridden via environment variables; otherwise they
# fall back to paths relative to the current working directory.
dataset_base_dir = os.environ.get("BASE_DATASET_DIR", "data")
metrics_cache = os.environ.get("FBA_METRICS_CACHE", ".cache")
data_dir = Path(dataset_base_dir, "fdf256")

# Data configuration for the FDF256 dataset: 3-channel images at 256x256,
# with separate train / val loaders and the evaluation functions to use.
data = dict(
    imsize=(256, 256),
    im_channels=3,
    # No semantic-map, CSE or keypoint inputs are configured for this dataset
    # (keypoints are also not loaded by the datasets below).
    semantic_nc=None,
    cse_nc=None,
    n_keypoints=None,
    train=dict(
        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
            # "${...}" strings are config interpolations (presumably resolved
            # OmegaConf-style by the config system): the batch size is taken
            # from the top-level `train` section defined elsewhere, and
            # "${..dataset}" refers to the sibling `dataset` entry above.
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=True,
            # Preprocessing/augmentation applied in order on the GPU side.
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(RandomHorizontalFlip)(p=0.5),
                L(Resize)(size="${data.imsize}"),
                # NOTE(review): assumes ToFloat yields values in [0, 1], so
                # mean=std=0.5 maps them to [-1, 1] — confirm in transforms.
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    val=dict(
        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
        loader=L(get_dataloader)(
            # Deterministic, finite pass over the validation split; the last
            # partial batch is kept.
            shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
            # Validation deliberately reuses the training batch size.
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=False,
            # Same pipeline as training, minus the random horizontal flip.
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    # It therefore uses its own metrics cache directory, separate from the final evaluation.
    train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")),
    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
)