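# Select the non-interactive Agg backend before anything imports pyplot, so the
# script can run on headless training machines.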
import matplotlib
matplotlib.use("Agg")
import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy
from frames_dataset import FramesDataset
from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector, HEEstimator
import torch
from train import train
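
# Training entry point. A typical invocation (the file name and GPU list are
# illustrative; adjust them to your checkout):
#   python run.py --config config/vox-256.yaml --device_ids 0,1 --gen spade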
if __name__ == "__main__":
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train"])
    parser.add_argument("--gen", default="original", choices=["original", "spade"])
    parser.add_argument("--log_dir", default="log", help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    parser.add_argument(
        "--device_ids",
        default="0, 1, 2, 3, 4, 5, 6, 7",
        type=lambda x: list(map(int, x.split(","))),
        help="Comma-separated device ids.",
    )
    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
    parser.set_defaults(verbose=False)

    opt = parser.parse_args()
    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
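
    # The config is expected to expose at least the keys read below; this sketch
    # only shows the layout and is illustrative, not copied from a concrete file:
    #
    #   model_params:
    #     common_params: {...}
    #     generator_params: {...}
    #     discriminator_params: {...}
    #     kp_detector_params: {...}
    #     he_estimator_params: {...}
    #   dataset_params: {...}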
    # Resume into the checkpoint's own directory; otherwise create a fresh,
    # time-stamped log directory named after the config file.
    if opt.checkpoint is not None:
        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
    else:
        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split(".")[0])
        log_dir += " " + strftime("%d_%m_%y_%H.%M.%S", gmtime())
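
    # Instantiate the generator selected by --gen: the original occlusion-aware
    # generator or the SPADE-based variant.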
    if opt.gen == "original":
        generator = OcclusionAwareGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"])
    elif opt.gen == "spade":
        generator = OcclusionAwareSPADEGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"])
    if torch.cuda.is_available():
        print("cuda is available")
        generator.to(opt.device_ids[0])
    if opt.verbose:
        print(generator)
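
    # The discriminator, keypoint detector and HEEstimator are constructed the
    # same way and moved to the first device in --device_ids.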
    discriminator = MultiScaleDiscriminator(**config["model_params"]["discriminator_params"], **config["model_params"]["common_params"])
    if torch.cuda.is_available():
        discriminator.to(opt.device_ids[0])
    if opt.verbose:
        print(discriminator)

    kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"])
    if torch.cuda.is_available():
        kp_detector.to(opt.device_ids[0])
    if opt.verbose:
        print(kp_detector)

    he_estimator = HEEstimator(**config["model_params"]["he_estimator_params"], **config["model_params"]["common_params"])
    if torch.cuda.is_available():
        he_estimator.to(opt.device_ids[0])
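
    # Dataset of video frames; its parameters come from dataset_params in the config.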
    dataset = FramesDataset(is_train=(opt.mode == "train"), **config["dataset_params"])
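
    # Create the log directory and keep a copy of the config next to the checkpoints.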
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)
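
    # Hand everything over to the training loop defined in train.py.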
    if opt.mode == "train":
        print("Training...")
        train(config, generator, discriminator, kp_detector, he_estimator, opt.checkpoint, log_dir, dataset, opt.device_ids)