|
"""Script to convert officially released models to match this repository.""" |
|
|
|
import argparse
import os

from converters import convert_pggan_weight
from converters import convert_stylegan_weight
from converters import convert_stylegan2_weight
from converters import convert_stylegan2ada_tf_weight
from converters import convert_stylegan2ada_pth_weight
|
|
|
|
|
def parse_args(argv=None):
    """Parses arguments.

    Args:
        argv: Optional list of command-line tokens to parse (without the
            program name). When `None` (the default), argparse falls back to
            `sys.argv[1:]`, so calling `parse_args()` behaves exactly as a
            plain command-line invocation.

    Returns:
        An `argparse.Namespace` with the attributes `model_type`,
        `source_model_path`, `target_model_path` (`None` unless given),
        `test_num`, `save_test_image` and `verbose_log`.
    """
    parser = argparse.ArgumentParser(description='Convert pre-trained models.')
    parser.add_argument('model_type', type=str,
                        choices=['pggan', 'stylegan', 'stylegan2',
                                 'stylegan2ada_tf', 'stylegan2ada_pth'],
                        help='Type of the model to convert')
    parser.add_argument('--source_model_path', type=str, required=True,
                        help='Path to load the model for conversion.')
    parser.add_argument('--target_model_path', type=str, default=None,
                        help='Path to save the converted model. If not '
                             'specified, the model will be saved to the same '
                             'directory of the source model.')
    parser.add_argument('--test_num', type=int, default=10,
                        help='Number of test samples used to check the '
                             'precision of the converted model. (default: 10)')
    parser.add_argument('--save_test_image', action='store_true',
                        help='Whether to save the test image. (default: False)')
    parser.add_argument('--verbose_log', action='store_true',
                        help='Whether to print verbose log. (default: False)')
    # Passing `argv` through (rather than always reading `sys.argv`) lets the
    # parser be reused from other Python code and from tests.
    return parser.parse_args(argv)
|
|
|
|
|
def _default_target_path(source_path):
    """Derives the converted-model path from `source_path` (same directory).

    Only the final extension of the source file is replaced with `.pth`
    (e.g. `dir/model.pkl` -> `dir/model.pth`); directory names are never
    touched. If that would yield the source path itself (the source already
    ends in `.pth`), `_converted.pth` is used instead so that the conversion
    can never overwrite its own input.
    """
    root, _ = os.path.splitext(source_path)
    target_path = root + '.pth'
    if target_path == source_path:
        target_path = root + '_converted.pth'
    return target_path


def main():
    """Main function.

    Parses the command line, fills in a default target path when none was
    given, and dispatches to the converter registered for `model_type`.

    Raises:
        NotImplementedError: If `model_type` has no registered converter.
    """
    args = parse_args()
    if args.target_model_path is None:
        args.target_model_path = _default_target_path(args.source_model_path)

    # Each entry: (converter, keyword for source path, keyword for target
    # path). The TF-based converters name their path arguments
    # `tf_weight_path`/`pth_weight_path`, while the PyTorch StyleGAN2-ADA
    # converter uses `src_weight_path`/`dst_weight_path`.
    dispatch = {
        'pggan': (convert_pggan_weight,
                  'tf_weight_path', 'pth_weight_path'),
        'stylegan': (convert_stylegan_weight,
                     'tf_weight_path', 'pth_weight_path'),
        'stylegan2': (convert_stylegan2_weight,
                      'tf_weight_path', 'pth_weight_path'),
        'stylegan2ada_tf': (convert_stylegan2ada_tf_weight,
                            'tf_weight_path', 'pth_weight_path'),
        'stylegan2ada_pth': (convert_stylegan2ada_pth_weight,
                             'src_weight_path', 'dst_weight_path'),
    }
    if args.model_type not in dispatch:
        raise NotImplementedError(f'Model type `{args.model_type}` is not '
                                  f'supported!')

    convert_fn, src_key, dst_key = dispatch[args.model_type]
    path_kwargs = {src_key: args.source_model_path,
                   dst_key: args.target_model_path}
    convert_fn(**path_kwargs,
               test_num=args.test_num,
               save_test_image=args.save_test_image,
               verbose=args.verbose_log)
|
|
|
|
|
# Run the conversion only when this file is executed as a script, so that
# importing it (e.g. to reuse `parse_args`) has no side effects.
if __name__ == '__main__':
    main()
|
|