File size: 1,527 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02a0ab
 
 
 
8c212a5
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import sys
from models.genforce.models import MODEL_ZOO
from models.genforce.models import build_generator
import os
import os.path as osp
import subprocess
import torch


def load_generator(model_name, latent_is_w=False, verbose=False, CHECKPOINT_DIR='models/pretrained/genforce/'):
    """Build a GenForce generator and load its pre-trained weights.

    The generator is constructed from the ``MODEL_ZOO`` entry for
    ``model_name``. Its checkpoint is read from
    ``<CHECKPOINT_DIR>/<model_name>.pth``; when that file does not exist it is
    first downloaded with ``wget`` from the entry's ``'url'``.

    Args:
        model_name: Key into ``MODEL_ZOO`` selecting the model configuration.
        latent_is_w: Forwarded to ``build_generator`` as the ``latent_is_w``
            configuration option.
        verbose: If true, print progress messages to stdout.
        CHECKPOINT_DIR: Directory holding (or receiving) checkpoint files. It
            is created if missing.

    Returns:
        The generator with weights loaded and an extra ``dim_z`` attribute set
        to the value of its ``z_space_dim``.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_ZOO``, or the checkpoint
            contains neither ``'generator_smooth'`` nor ``'generator'``.
        RuntimeError: If the checkpoint download fails or yields an empty file.
    """
    if verbose:
        print("  \\__Building generator for model {}...".format(model_name), end="")

    # Work on a copy so that popping 'url' does not alter the MODEL_ZOO entry.
    model_config = MODEL_ZOO[model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.
    model_config.update({'latent_is_w': latent_is_w})

    # Build generator
    generator = build_generator(**model_config)
    if verbose:
        print("Done!")

    # Load pre-trained weights.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = osp.join(CHECKPOINT_DIR, model_name + '.pth')

    if verbose:
        print("  \\__Loading checkpoint from {}...".format(checkpoint_path), end="")

    if not osp.exists(checkpoint_path):
        # `wget -O` creates its output file even when the transfer fails, so
        # download to a temporary sibling and only move it into place once the
        # download is known to be good. Otherwise a failed download would leave
        # a broken file that suppresses every later download attempt.
        partial_path = checkpoint_path + '.part'
        exit_code = subprocess.call(['wget', '--quiet', '-O', partial_path, url])
        downloaded_ok = (
            exit_code == 0
            and osp.exists(partial_path)
            and osp.getsize(partial_path) > 0
        )
        if not downloaded_ok:
            if osp.exists(partial_path):
                os.remove(partial_path)
            raise RuntimeError(
                "Failed to download checkpoint for model '{}' from {} "
                "(wget exit code {}).".format(model_name, url, exit_code)
            )
        os.replace(partial_path, checkpoint_path)

    # Map tensors to the GPU when one is available, otherwise to the CPU.
    # NOTE(review): torch.load unpickles the file; this assumes the MODEL_ZOO
    # URLs point to trusted sources -- confirm.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Prefer the 'generator_smooth' weights when the checkpoint provides them
    # (presumably the moving-average copy of the generator -- TODO confirm).
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    if verbose:
        print("Done!")

    # Expose the latent dimensionality under the name `dim_z` as well.
    generator.dim_z = generator.z_space_dim

    return generator