import os

import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def continue_training(checkpoint_path: str, generator: DDP, mpd: DDP, mrd: DDP,
                      optimizer_d: optim.Optimizer, optimizer_g: optim.Optimizer) -> int:
    """Load the latest complete set of checkpoints and optimizer states,
    and return the epoch to resume training from (0 if none is found)."""
    generator_dict = {}
    mpd_dict = {}
    mrd_dict = {}
    optimizer_d_dict = {}
    optimizer_g_dict = {}
    
    # Collect every checkpoint in the directory, grouped by component.
    # Files are expected to follow the <component>_<epoch>.pt convention,
    # e.g. generator_12.pt or optimizerd_12.pt.
    for file in os.listdir(checkpoint_path):
        if not file.endswith(".pt") or '_' not in file:
            continue
        name, epoch_str = file.rsplit('_', 1)
        try:
            epoch = int(epoch_str.split('.')[0])
        except ValueError:
            continue  # suffix is not a numeric epoch; skip this file

        if name.startswith("generator"):
            generator_dict[epoch] = file
        elif name.startswith("mpd"):
            mpd_dict[epoch] = file
        elif name.startswith("mrd"):
            mrd_dict[epoch] = file
        elif name.startswith("optimizerd"):
            optimizer_d_dict[epoch] = file
        elif name.startswith("optimizerg"):
            optimizer_g_dict[epoch] = file
    
    # Resume only from the newest epoch for which all five checkpoint
    # files (three models and two optimizer states) are present.
    common_epochs = (set(generator_dict) & set(mpd_dict) & set(mrd_dict)
                     & set(optimizer_d_dict) & set(optimizer_g_dict))
    if common_epochs:
        max_epoch = max(common_epochs)
        generator_path = os.path.join(checkpoint_path, generator_dict[max_epoch])
        mpd_path = os.path.join(checkpoint_path, mpd_dict[max_epoch])
        mrd_path = os.path.join(checkpoint_path, mrd_dict[max_epoch])
        optimizer_d_path = os.path.join(checkpoint_path, optimizer_d_dict[max_epoch])
        optimizer_g_path = os.path.join(checkpoint_path, optimizer_g_dict[max_epoch])
        
        # Load weights into the underlying modules (unwrapping DDP via
        # .module); map_location='cpu' keeps loading device-agnostic.
        generator.module.load_state_dict(torch.load(generator_path, map_location='cpu'))
        mpd.module.load_state_dict(torch.load(mpd_path, map_location='cpu'))
        mrd.module.load_state_dict(torch.load(mrd_path, map_location='cpu'))
        optimizer_d.load_state_dict(torch.load(optimizer_d_path, map_location='cpu'))
        optimizer_g.load_state_dict(torch.load(optimizer_g_path, map_location='cpu'))
        
        print(f'Resuming models and optimizers from epoch {max_epoch}')
        return max_epoch + 1  # continue training with the next epoch
    
    # No complete checkpoint set was found; start training from scratch.
    return 0
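

# ---------------------------------------------------------------------------
# Hypothetical save-side counterpart (not part of the original file): a
# minimal sketch included only to document the <component>_<epoch>.pt naming
# convention that continue_training relies on when scanning the directory.
# ---------------------------------------------------------------------------
def save_checkpoints(checkpoint_path: str, generator: DDP, mpd: DDP, mrd: DDP,
                     optimizer_d: optim.Optimizer, optimizer_g: optim.Optimizer,
                     epoch: int) -> None:
    """Write one .pt file per component for the given epoch."""
    os.makedirs(checkpoint_path, exist_ok=True)
    # DDP wraps the real model, so save the underlying .module's weights.
    torch.save(generator.module.state_dict(),
               os.path.join(checkpoint_path, f"generator_{epoch}.pt"))
    torch.save(mpd.module.state_dict(),
               os.path.join(checkpoint_path, f"mpd_{epoch}.pt"))
    torch.save(mrd.module.state_dict(),
               os.path.join(checkpoint_path, f"mrd_{epoch}.pt"))
    torch.save(optimizer_d.state_dict(),
               os.path.join(checkpoint_path, f"optimizerd_{epoch}.pt"))
    torch.save(optimizer_g.state_dict(),
               os.path.join(checkpoint_path, f"optimizerg_{epoch}.pt"))


# Usage sketch (assumptions, not runnable as-is): Generator,
# MultiPeriodDiscriminator, and MultiResolutionDiscriminator are placeholder
# model classes defined elsewhere, and the optimizer settings are purely
# illustrative.
#
#   generator = DDP(Generator().to(rank), device_ids=[rank])
#   mpd = DDP(MultiPeriodDiscriminator().to(rank), device_ids=[rank])
#   mrd = DDP(MultiResolutionDiscriminator().to(rank), device_ids=[rank])
#   optimizer_g = optim.AdamW(generator.parameters(), lr=2e-4)
#   optimizer_d = optim.AdamW(
#       list(mpd.parameters()) + list(mrd.parameters()), lr=2e-4)
#
#   start_epoch = continue_training(
#       checkpoint_dir, generator, mpd, mrd, optimizer_d, optimizer_g)
#   for epoch in range(start_epoch, num_epochs):
#       ...  # one training epoch, then save_checkpoints(checkpoint_dir, ...)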