File size: 3,251 Bytes
97a6728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.fdh import get_dataloader_fdh_wds
from dp2.data.utils import get_coco_flipmap
from dp2.data.transforms.transforms import (
    Normalize,
    ToFloat,
    CreateCondition,
    RandomHorizontalFlip,
    CreateEmbedding,
)
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip
from dp2.metrics.ppl import calculate_ppl
from .utils import train_eval_fn


def final_eval_fn(*args, upsample_size=(288, 160), **kwargs):
    """Run the final evaluation suite and merge all metrics into one dict.

    Calls ``compute_metrics_iteratively``, ``calculate_ppl`` and
    ``compute_fid_clip`` (in that order) with the same positional and keyword
    arguments, then merges the PPL and FID-CLIP results into the dict returned
    by ``compute_metrics_iteratively``.

    Args:
        *args: Forwarded unchanged to every metric function.
        upsample_size: Forwarded only to ``calculate_ppl``. Defaults to
            ``(288, 160)``, the same value as ``data.imsize`` in this config.
        **kwargs: Forwarded unchanged to every metric function.

    Returns:
        The (mutated) dict from ``compute_metrics_iteratively``, extended with
        the PPL and FID-CLIP metrics.

    Raises:
        ValueError: If a later metric function reports a key already present
            in the merged result; merging would otherwise silently overwrite
            an earlier value.
    """
    result = compute_metrics_iteratively(*args, **kwargs)
    # Tuple elements are evaluated left to right, so PPL runs before FID-CLIP
    # exactly as before.
    extra_results = (
        calculate_ppl(*args, **kwargs, upsample_size=upsample_size),
        compute_fid_clip(*args, **kwargs),
    )
    for extra in extra_results:
        # Check every merged result (not only PPL) and use an explicit raise
        # rather than `assert`, which is stripped when Python runs with -O.
        clashes = set(result).intersection(extra)
        if clashes:
            raise ValueError(
                "Metric keys reported by more than one metric function: "
                f"{sorted(clashes, key=str)}"
            )
        result.update(extra)
    return result


def get_cache_directory(imsize, subset):
    """Build the metrics-cache directory path for a dataset subset.

    The leaf directory name is ``subset`` immediately followed by the first
    element of ``imsize``; it is placed under the module-level
    ``metrics_cache`` root (``FBA_METRICS_CACHE`` env var, default
    ``".cache"``).

    Example: subset ``"fdh"`` with imsize ``(288, 160)`` gives
    ``<metrics_cache>/fdh288``.
    """
    leading_dim = imsize[0]
    return Path(metrics_cache) / f"{subset}{leading_dim}"

# Root directories, overridable through environment variables.
dataset_base_dir = os.environ.get("BASE_DATASET_DIR", "data")
metrics_cache = os.environ.get("FBA_METRICS_CACHE", ".cache")
data_dir = Path(dataset_base_dir) / "fdh"

# FDH data configuration. Values wrapped in L(...) are lazy calls; the
# "${...}" strings are presumably config interpolations resolved by the
# config loader (confirm against tops.config).
data = {
    "imsize": (288, 160),
    "im_channels": 3,
    "cse_nc": 16,
    "n_keypoints": 17,
    "train": {
        "loader": L(get_dataloader_fdh_wds)(
            path=data_dir / "train" / "out-{000000..001423}.tar",
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=L(torch.nn.Sequential)(
                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
            ),
            gpu_transform=L(torch.nn.Sequential)(
                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
                L(CreateEmbedding)(embed_path=data_dir / "embed_map.torch"),
                # Maps [0, 255] pixel values to [-1, 1].
                L(Normalize)(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3, inplace=True),
                L(CreateCondition)(),
            ),
            infinite=True,
            shuffle=True,
            partial_batches=False,
            load_embedding=True,
            keypoints_split="train",
            load_new_keypoints=False,
        ),
    },
    "val": {
        "loader": L(get_dataloader_fdh_wds)(
            path=data_dir / "val" / "out-{000000..000023}.tar",
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=None,
            # Reuses the training GPU transform pipeline.
            gpu_transform="${data.train.loader.gpu_transform}",
            infinite=False,
            shuffle=False,
            partial_batches=True,
            load_embedding=True,
            keypoints_split="val",
            load_new_keypoints="${data.train.loader.load_new_keypoints}",
        ),
    },
    # Training evaluation might do optimizations to reduce compute overhead, e.g. compute with AMP.
    "train_evaluation_fn": L(functools.partial)(
        train_eval_fn,
        cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
        data_len=30_000,
    ),
    "evaluation_fn": L(functools.partial)(
        final_eval_fn,
        cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
        data_len=30_000,
    ),
}