File size: 3,951 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Script to convert officially released models to match this repository."""

import argparse
import os

from converters import convert_pggan_weight
from converters import convert_stylegan_weight
from converters import convert_stylegan2_weight
from converters import convert_stylegan2ada_tf_weight
from converters import convert_stylegan2ada_pth_weight


def parse_args():
    """Builds the command-line parser and parses the process arguments.

    Returns:
        An `argparse.Namespace` with the fields `model_type`,
        `source_model_path`, `target_model_path`, `test_num`,
        `save_test_image` and `verbose_log`.
    """
    supported_types = ['pggan', 'stylegan', 'stylegan2',
                       'stylegan2ada_tf', 'stylegan2ada_pth']
    target_help = ('Path to save the converted model. If not specified, the '
                   'model will be saved to the same directory of the source '
                   'model.')
    test_num_help = ('Number of test samples used to check the precision of the '
                     'converted model. (default: 10)')
    # Boolean switches share the same shape, so they are registered in a loop.
    switches = (
        ('--save_test_image',
         'Whether to save the test image. (default: False)'),
        ('--verbose_log',
         'Whether to print verbose log. (default: False)'),
    )

    cli = argparse.ArgumentParser(description='Convert pre-trained models.')
    # The only positional argument; restricted to the supported model types.
    cli.add_argument('model_type', type=str, choices=supported_types,
                     help='Type of the model to convert')
    cli.add_argument('--source_model_path', type=str, required=True,
                     help='Path to load the model for conversion.')
    cli.add_argument('--target_model_path', type=str, default=None,
                     help=target_help)
    cli.add_argument('--test_num', type=int, default=10, help=test_num_help)
    for switch, switch_help in switches:
        cli.add_argument(switch, action='store_true', help=switch_help)

    return cli.parse_args()


def _default_target_path(source_model_path):
    """Derives the default save path for a converted model.

    Only the final file extension is swapped for `.pth`, so directory names
    and other parts of the path are left untouched.

    Args:
        source_model_path: Path of the model to convert.

    Returns:
        The source path with its extension replaced by `.pth`.

    Raises:
        ValueError: If the derived path is identical to the source path, in
            which case saving there would write over the source model.
    """
    root, _ = os.path.splitext(source_model_path)
    target_model_path = root + '.pth'
    if target_model_path == source_model_path:
        raise ValueError(f'Cannot derive a target path different from the '
                         f'source model `{source_model_path}`. Please set '
                         f'`--target_model_path` explicitly.')
    return target_model_path


def main():
    """Main function.

    Parses the command-line arguments, fills in the default target path when
    none is given, and forwards the arguments to the converter that matches
    `model_type`.

    Raises:
        ValueError: If no target path is given and the default one would be
            the same file as the source model.
        NotImplementedError: If `model_type` matches none of the converters.
    """
    args = parse_args()
    if args.target_model_path is None:
        args.target_model_path = _default_target_path(args.source_model_path)

    # Keyword arguments that every converter call below receives unchanged.
    check_kwargs = dict(test_num=args.test_num,
                        save_test_image=args.save_test_image,
                        verbose=args.verbose_log)

    if args.model_type == 'pggan':
        convert_pggan_weight(tf_weight_path=args.source_model_path,
                             pth_weight_path=args.target_model_path,
                             **check_kwargs)
    elif args.model_type == 'stylegan':
        convert_stylegan_weight(tf_weight_path=args.source_model_path,
                                pth_weight_path=args.target_model_path,
                                **check_kwargs)
    elif args.model_type == 'stylegan2':
        convert_stylegan2_weight(tf_weight_path=args.source_model_path,
                                 pth_weight_path=args.target_model_path,
                                 **check_kwargs)
    elif args.model_type == 'stylegan2ada_tf':
        convert_stylegan2ada_tf_weight(tf_weight_path=args.source_model_path,
                                       pth_weight_path=args.target_model_path,
                                       **check_kwargs)
    elif args.model_type == 'stylegan2ada_pth':
        # This converter names its path arguments `src_`/`dst_` instead of
        # `tf_`/`pth_`.
        convert_stylegan2ada_pth_weight(src_weight_path=args.source_model_path,
                                        dst_weight_path=args.target_model_path,
                                        **check_kwargs)
    else:
        raise NotImplementedError(f'Model type `{args.model_type}` is not '
                                  f'supported!')


# Run the conversion only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()