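"""Utilities for building trained generators and discriminators from a config.

These helpers instantiate the networks defined in the config, load their trained
weights from a URL or a local checkpoint directory, and tolerate partially
mismatched state dicts.
"""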
import tops
import torch
from tops import checkpointer
from tops.config import instantiate
from tops.logger import warn

def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
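    """Load the (EMA) generator weights from a checkpoint dict into G.

    Prefers the "EMA_generator" entry and falls back to "running_average_generator".
    An optional ckpt_mapper can remap the state dict before loading. If the
    checkpoint contains "w_centers", they are attached to G.style_net.
    """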
    state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"]
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)
    load_state_dict(G, state)
    tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
    print(ckpt.keys())
    if "w_centers" in ckpt:
        print("Has w_centers!")
        G.style_net.w_centers = ckpt["w_centers"]
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")


def build_trained_generator(cfg, map_location=None):
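    """Instantiate the generator from cfg and load its trained (EMA) weights.

    Weights come from cfg.common.model_url when set, otherwise from a checkpoint
    in cfg.checkpoint_dir. Returns the generator in eval mode on the requested
    device; if no checkpoint is found, the freshly initialized generator is returned.
    """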
    map_location = map_location if map_location is not None else tops.get_device()
    G = instantiate(cfg.generator).to(map_location)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    if hasattr(cfg, "ckpt_mapper"):
        ckpt_mapper = instantiate(cfg.ckpt_mapper)
    else:
        ckpt_mapper = None
    if "model_url" in cfg.common:
        ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum, map_location=torch.device("cpu"))
        load_generator_state(ckpt, G, ckpt_mapper)
        return G
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        load_generator_state(ckpt, G, ckpt_mapper)
    except FileNotFoundError:
        tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
    return G


def build_trained_discriminator(cfg, map_location=None):
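    """Instantiate the discriminator from cfg and load its trained weights.

    Weights are read from a checkpoint in cfg.checkpoint_dir, optionally remapped
    by cfg.ckpt_mapper_D. Returns the discriminator in eval mode; if no checkpoint
    is found, the freshly initialized discriminator is returned.
    """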
    map_location = map_location if map_location is not None else tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        if hasattr(cfg, "ckpt_mapper_D"):
            ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"])
        D.load_state_dict(ckpt["discriminator"])
    except FileNotFoundError:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
    return D


def load_state_dict(module: torch.nn.Module, state_dict: dict):
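    """Load state_dict into module, tolerating missing or shape-mismatched keys.

    Entries whose tensor shape disagrees with the module are dropped with a warning,
    keys present in only one of the two state dicts are reported, and the remainder
    is loaded with strict=False.
    """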
    module_sd = module.state_dict()
    to_remove = []
    for key, item in state_dict.items():
        if key not in module_sd:
            continue
        if item.shape != module_sd[key].shape:
            to_remove.append(key)
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
    for key in to_remove:
        state_dict.pop(key)
    for key, item in state_dict.items():
        if key not in module_sd:
            warn(f"Did not fin key in model state dict: {key}")
    for key, item in module_sd.items():
        if key not in state_dict:
            warn(f"Did not find key in state dict: {key}")
    module.load_state_dict(state_dict, strict=False)
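

# Usage sketch (illustrative, not part of the original module). Assumes `cfg` is an
# already-loaded config object for this project; how the config is loaded is
# project-specific and omitted here, and `batch` below is a hypothetical dict of
# generator inputs.
#
#   G = build_trained_generator(cfg)       # EMA generator on the default device, eval mode
#   D = build_trained_discriminator(cfg)   # discriminator with trained weights, eval mode
#   with torch.no_grad():
#       out = G(**batch)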