""" Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. """ import torch from diffusion import gaussian_diffusion as gd from diffusion.respace import space_timesteps, SpacedDiffusion from model.diffusion import FiLMTransformer from torch.nn import functional as F def get_person_num(config_path): if "PXB184" in config_path: person = "PXB184" elif "RLW104" in config_path: person = "RLW104" elif "TXB805" in config_path: person = "TXB805" elif "GQS883" in config_path: person = "GQS883" else: assert False, f"something wrong with config: {config_path}" return person def load_model(model, state_dict): missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) assert len(unexpected_keys) == 0, unexpected_keys assert all( [ k.startswith("transformer.") or k.startswith("tokenizer.") for k in missing_keys ] ), missing_keys def create_model_and_diffusion(args, split_type): model = FiLMTransformer(**get_model_args(args, split_type=split_type)).to( torch.float32 ) diffusion = create_gaussian_diffusion(args) return model, diffusion def get_model_args(args, split_type): if args.data_format == "face": nfeat = 256 lfeat = 512 elif args.data_format == "pose": nfeat = 104 lfeat = 256 if not hasattr(args, "num_audio_layers"): args.num_audio_layers = 3 # backwards compat model_args = { "args": args, "nfeats": nfeat, "latent_dim": lfeat, "ff_size": 1024, "num_layers": args.layers, "num_heads": args.heads, "dropout": 0.1, "cond_feature_dim": 512 * 2, "activation": F.gelu, "use_rotary": not args.not_rotary, "cond_mode": "uncond" if args.unconstrained else "audio", "split_type": split_type, "num_audio_layers": args.num_audio_layers, "device": args.device, } return model_args def create_gaussian_diffusion(args): predict_xstart = True steps = 1000 scale_beta = 1.0 timestep_respacing = args.timestep_respacing learn_sigma = False rescale_timesteps = False betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) loss_type = gd.LossType.MSE if not timestep_respacing: timestep_respacing = [steps] name = args.save_dir if hasattr(args, "save_dir") else args.model_path return SpacedDiffusion( use_timesteps=space_timesteps(steps, timestep_respacing), betas=betas, model_mean_type=( gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X ), model_var_type=( ( gd.ModelVarType.FIXED_LARGE if not args.sigma_small else gd.ModelVarType.FIXED_SMALL ) if not learn_sigma else gd.ModelVarType.LEARNED_RANGE ), data_format=args.data_format, loss_type=loss_type, rescale_timesteps=rescale_timesteps, lambda_vel=args.lambda_vel, model_path=name, )