import torch
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import trange

from frames_dataset import DatasetRepeater
from logger import Logger
from modules.model import GeneratorFullModel, DiscriminatorFullModel
from sync_batchnorm import DataParallelWithCallback


def train(config, generator, discriminator, kp_detector, he_estimator, checkpoint, log_dir, dataset, device_ids):
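    """Full training loop for the generator, discriminator, keypoint
    detector, and head-pose (he) estimator.

    `config` is expected to carry 'train_params', 'model_params', and
    'visualizer_params' sections; `checkpoint` may be None or a path to a
    checkpoint previously written by Logger.
    """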
    train_params = config["train_params"]

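    # One Adam optimizer per sub-network, each with its own learning rate;
    # betas=(0.5, 0.999) is the usual GAN setting.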
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params["lr_generator"], betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params["lr_discriminator"], betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params["lr_kp_detector"], betas=(0.5, 0.999))
    optimizer_he_estimator = torch.optim.Adam(he_estimator.parameters(), lr=train_params["lr_he_estimator"], betas=(0.5, 0.999))

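    # When resuming, Logger.load_cpk restores every network and optimizer
    # state in place and returns the epoch to continue from.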
    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint,
            generator,
            discriminator,
            kp_detector,
            he_estimator,
            optimizer_generator,
            optimizer_discriminator,
            optimizer_kp_detector,
            optimizer_he_estimator,
        )
    else:
        start_epoch = 0

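    # Decay each learning rate by 10x at the configured epoch milestones.
    # last_epoch is offset by start_epoch so the schedule stays aligned
    # when resuming from a checkpoint.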
    scheduler_generator = MultiStepLR(optimizer_generator, train_params["epoch_milestones"], gamma=0.1, last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params["epoch_milestones"], gamma=0.1, last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(
        optimizer_kp_detector, train_params["epoch_milestones"], gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_kp_detector"] != 0)
    )
    scheduler_he_estimator = MultiStepLR(
        optimizer_he_estimator, train_params["epoch_milestones"], gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_he_estimator"] != 0)
    )

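    # Optionally repeat the dataset so a single epoch makes num_repeats
    # passes over the data; note num_workers=16 is hard-coded for the loader.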
    if "num_repeats" in train_params or train_params["num_repeats"] != 1:
        dataset = DatasetRepeater(dataset, train_params["num_repeats"])
    dataloader = DataLoader(dataset, batch_size=train_params["batch_size"], shuffle=True, num_workers=16, drop_last=True)

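    # The "full" model wrappers compute all training losses inside forward(),
    # so the loss computation is replicated per GPU under data parallelism.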
    generator_full = GeneratorFullModel(
        kp_detector,
        he_estimator,
        generator,
        discriminator,
        train_params,
        estimate_jacobian=config["model_params"]["common_params"]["estimate_jacobian"],
    )
    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

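    # DataParallelWithCallback (from sync_batchnorm) behaves like
    # nn.DataParallel but keeps batch-norm statistics synchronized across
    # the replicas on device_ids.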
    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

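    # Logger records per-iteration losses, saves checkpoints every
    # checkpoint_freq epochs, and renders visualizations; using it as a
    # context manager closes its log files on exit.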
    with Logger(log_dir=log_dir, visualizer_params=config["visualizer_params"], checkpoint_freq=train_params["checkpoint_freq"]) as logger:
        for epoch in trange(start_epoch, train_params["num_epochs"]):
            for x in dataloader:
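                # Generator-side forward pass: returns a dict of named losses
                # and the generated output used later for visualization.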
                losses_generator, generated = generator_full(x)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

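                # A single backward pass fills gradients for the generator,
                # keypoint detector, and head-pose estimator; step and clear
                # each of their optimizers in turn.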
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()
                optimizer_he_estimator.step()
                optimizer_he_estimator.zero_grad()

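                # Train the discriminator only when the adversarial loss is
                # actually enabled.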
                if train_params["loss_weights"]["generator_gan"] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [val.mean() for val in losses_discriminator.values()]
                    loss = sum(loss_values)

                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

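                # Merge discriminator losses into the generator dict so one
                # flat dictionary of scalars is logged per iteration.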
                losses_generator.update(losses_discriminator)
                losses = {key: value.mean().detach().cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

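            # Learning rate schedules advance once per epoch.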
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            scheduler_he_estimator.step()

            logger.log_epoch(
                epoch,
                {
                    "generator": generator,
                    "discriminator": discriminator,
                    "kp_detector": kp_detector,
                    "he_estimator": he_estimator,
                    "optimizer_generator": optimizer_generator,
                    "optimizer_discriminator": optimizer_discriminator,
                    "optimizer_kp_detector": optimizer_kp_detector,
                    "optimizer_he_estimator": optimizer_he_estimator,
                },
                inp=x,
                out=generated,
            )
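

# A minimal usage sketch (the config path, YAML layout, and the pre-built
# networks below are assumptions for illustration; they are not defined in
# this file):
#
#   import yaml
#   with open("config/vox-256.yaml") as f:   # hypothetical config file
#       config = yaml.safe_load(f)
#   train(config, generator, discriminator, kp_detector, he_estimator,
#         checkpoint=None, log_dir="log", dataset=dataset, device_ids=[0])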