Spaces:
Build error
Build error
| import os | |
| import time | |
| import argparse | |
| import yaml, math | |
| from tqdm import trange | |
| import torch | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| import torch.distributed as dist | |
| from pytorch_lightning import seed_everything | |
| from lvdm.samplers.ddim import DDIMSampler | |
| from lvdm.utils.common_utils import str2bool | |
| from lvdm.utils.dist_utils import setup_dist, gather_data | |
| from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d | |
| from scripts.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np | |
| # ------------------------------------------------------------------------------------------ | |
def get_parser():
    """Build the command-line parser for the text-to-video sampling script.

    The options are grouped as basic (checkpoint / config / prompt / output
    dir), device (DDP or single GPU), sampling (sample counts, sampler type,
    DDIM settings, guidance scale, seed), LoRA and saving options.

    Returns:
        argparse.ArgumentParser: the configured parser. Nothing is parsed
        here; the caller decides between ``parse_args`` and
        ``parse_known_args``.
    """
    parser = argparse.ArgumentParser()

    # basic args
    parser.add_argument("--ckpt_path", type=str, help="model checkpoint path")
    parser.add_argument("--config_path", type=str, help="model config path (a yaml file)")
    parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).")
    parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/")

    # device args
    parser.add_argument("--ddp", action='store_true', help="whether use pytorch ddp mode for parallel sampling (recommend for multi-gpu case)", default=False)
    parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0)
    parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0)

    # sampling args
    parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2)
    parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1)
    parser.add_argument("--decode_frame_bs", type=int, help="frame batch size for framewise decoding", default=1)
    parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddim", choices=["ddpm", "ddim"])
    parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50)
    parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0)
    parser.add_argument("--cfg_scale", type=float, default=15.0, help="classifier-free guidance scale")
    parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)")
    parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch")

    # lora args
    parser.add_argument("--lora_path", type=str, help="lora checkpoint path")
    parser.add_argument("--inject_lora", action='store_true', default=False, help="whether to inject lora weights when loading the model")
    parser.add_argument("--lora_scale", type=float, default=None, help="scale for lora weight")
    parser.add_argument("--lora_trigger_word", type=str, default="", help="text appended to every prompt when --inject_lora is set")

    # saving args
    # NOTE: argparse checks `choices` AFTER applying `type`. The previous
    # string choices ("True", "true", ...) could therefore never match the
    # converted value (str2bool is expected to return a bool -- confirm),
    # which made every explicit `--save_mp4 <value>` fail with
    # "invalid choice". Validation is left to str2bool instead.
    parser.add_argument("--save_mp4", type=str2bool, default=True, metavar="BOOL", help="whether save samples in separate mp4 files")
    parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file")
    parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file")
    parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file")
    parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos")
    return parser
| # ------------------------------------------------------------------------------------------ | |
def sample_denoising_batch(model, noise_shape, condition, *args,
                           sample_type="ddim", sampler=None,
                           ddim_steps=None, eta=None,
                           unconditional_guidance_scale=1.0, uc=None,
                           denoising_progress=False,
                           **kwargs,
                           ):
    """Run one denoising pass and return the sampled batch.

    Args:
        model: object exposing ``p_sample_loop``; only used when
            ``sample_type == "ddpm"``.
        noise_shape: full shape of the batch to sample. For DDIM the first
            entry is passed as the batch size and the rest as the
            per-sample shape.
        condition: conditioning forwarded to the sampler.
        *args: accepted for call compatibility and ignored.
        sample_type: ``"ddpm"`` or ``"ddim"``.
        sampler: object exposing ``sample(...)``; required for DDIM.
        ddim_steps: number of DDIM steps (passed as ``S``); required for DDIM.
        eta: DDIM ``eta`` value; required for DDIM.
        unconditional_guidance_scale: guidance scale, forwarded for DDIM only.
        uc: unconditional conditioning, forwarded for DDIM only.
        denoising_progress: forwarded as ``verbose`` to the sampler.
        **kwargs: extra keyword arguments forwarded to the sampler call.

    Returns:
        Whatever the sampler produces for the batch: the return value of
        ``model.p_sample_loop`` for DDPM, or the first element of the tuple
        returned by ``sampler.sample`` for DDIM.

    Raises:
        ValueError: if ``sample_type`` is not recognised, or if a required
            DDIM argument (``sampler``, ``ddim_steps``, ``eta``) is ``None``.
    """
    if sample_type == "ddpm":
        return model.p_sample_loop(cond=condition, shape=noise_shape,
                                   return_intermediates=False,
                                   verbose=denoising_progress,
                                   **kwargs,
                                   )

    if sample_type == "ddim":
        # Explicit checks instead of `assert`, which is stripped under -O.
        required = (("sampler", sampler), ("ddim_steps", ddim_steps), ("eta", eta))
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                f"ddim sampling requires non-None arguments: {', '.join(missing)}"
            )
        samples, _ = sampler.sample(S=ddim_steps,
                                    conditioning=condition,
                                    batch_size=noise_shape[0],
                                    shape=noise_shape[1:],
                                    verbose=denoising_progress,
                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                    unconditional_conditioning=uc,
                                    eta=eta,
                                    **kwargs,
                                    )
        return samples

    raise ValueError(
        f"unknown sample_type {sample_type!r}; expected 'ddpm' or 'ddim'"
    )
| # ------------------------------------------------------------------------------------------ | |
def sample_text2video(model, prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None,
                      ddim_steps=50, eta=1.0, cfg_scale=7.5,
                      decode_frame_bs=1,
                      ddp=False, all_gather=True,
                      batch_progress=True, show_denoising_progress=False,
                      ):
    """Sample videos for a single text prompt, batch by batch.

    Conditioning is built once from ``prompt``; an empty-prompt conditioning
    is built as well unless ``cfg_scale`` is exactly 1.0. Then
    ``ceil(n_samples / batch_size)`` batches are denoised, decoded with
    ``model.decode_first_stage`` and converted to numpy. When both ``ddp``
    and ``all_gather`` are set, each decoded batch is first gathered with
    ``gather_data``.

    Returns:
        A numpy array of all batches concatenated along axis 0. It may hold
        more than ``n_samples`` entries, since whole batches are kept (and
        gathered results are all appended).
    """
    assert model.cond_stage_model is not None

    # conditioning vectors (the unconditional one only matters for guidance)
    text_cond = get_conditions(prompt, model, batch_size)
    if cfg_scale != 1.0:
        null_cond = get_conditions("", model, batch_size)
    else:
        null_cond = None

    # batch schedule, optionally wrapped in a progress bar
    num_batches = math.ceil(n_samples / batch_size)
    if batch_progress:
        batch_iter = trange(num_batches, desc="Sampling Batches (text-to-video)")
    else:
        batch_iter = range(num_batches)

    gather_across_ranks = ddp and all_gather
    video_chunks = []
    for _ in batch_iter:
        latent_shape = make_model_input_shape(model, batch_size)
        latents = sample_denoising_batch(
            model, latent_shape, text_cond,
            sample_type=sample_type,
            sampler=sampler,
            ddim_steps=ddim_steps,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            uc=null_cond,
            denoising_progress=show_denoising_progress,
        )
        decoded = model.decode_first_stage(
            latents, decode_bs=decode_frame_bs, return_cpu=False
        )

        if gather_across_ranks:
            # collect this batch from every process before converting
            gathered = gather_data(decoded, return_np=False)
            video_chunks += [torch_to_np(part) for part in gathered]
        else:
            video_chunks.append(torch_to_np(decoded))

    videos = np.concatenate(video_chunks, axis=0)
    assert videos.shape[0] >= n_samples
    return videos
| # ------------------------------------------------------------------------------------------ | |
def save_results(videos, save_dir,
                 save_name="results", save_fps=8, save_mp4=True,
                 save_npz=False, save_mp4_sheet=False, save_jpg=False
                 ):
    """Write sampled videos to disk in the requested formats.

    Args:
        videos: numpy array of samples; it is indexed along axis 0 to write
            one mp4 per sample, and ``videos.shape[1]`` is used as ``nrow``
            for the jpg sheet.
        save_dir: output directory. It is created on demand when any format
            is requested.
        save_name: base file name (without extension) for all outputs.
        save_fps: frame rate passed to the mp4 writers.
        save_mp4: write ``<save_dir>/videos/<save_name>_NNN.mp4``, one file
            per sample.
        save_npz: write ``<save_dir>/<save_name>.npz`` with the raw array.
        save_mp4_sheet: write all samples into ``<save_dir>/<save_name>.mp4``.
        save_jpg: write an image sheet to ``<save_dir>/<save_name>.jpg``.

    Returns:
        None. With every flag false, nothing is written or created.
    """
    # Only the mp4 branch used to create its directory, so the other formats
    # failed when `save_dir` did not exist yet. An empty `save_dir` means the
    # current directory and needs no creation.
    if save_dir and (save_npz or save_mp4_sheet or save_jpg):
        os.makedirs(save_dir, exist_ok=True)

    if save_mp4:
        save_subdir = os.path.join(save_dir, "videos")
        os.makedirs(save_subdir, exist_ok=True)
        for i in range(videos.shape[0]):
            # slice (not index) so each clip keeps its leading batch axis
            npz_to_video_grid(videos[i:i+1, ...],
                              os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4"),
                              fps=save_fps)
        print(f'Successfully saved videos in {save_subdir}')

    if save_npz:
        save_path = os.path.join(save_dir, f"{save_name}.npz")
        np.savez(save_path, videos)
        print(f'Successfully saved npz in {save_path}')

    if save_mp4_sheet:
        save_path = os.path.join(save_dir, f"{save_name}.mp4")
        npz_to_video_grid(videos, save_path, fps=save_fps)
        print(f'Successfully saved mp4 sheet in {save_path}')

    if save_jpg:
        save_path = os.path.join(save_dir, f"{save_name}.jpg")
        npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1])
        print(f'Successfully saved jpg sheet in {save_path}')
| # ------------------------------------------------------------------------------------------ | |
def main():
    """
    text-to-video generation

    Parses the command line, sets up the device (DDP or a single GPU), seeds
    the RNGs when requested, dumps the parsed args to
    ``<save_dir>/sampling_args.yaml``, loads the model, and then samples and
    saves videos for every prompt. Command-line tokens the parser does not
    know are treated as OmegaConf dotlist overrides for the model config.

    ``--prompt`` is either a sentence or a path ending in ``.txt``; in the
    latter case every non-empty line is used as a prompt and the file is
    copied into the save dir (best effort).
    """
    import shutil  # local import: only needed to copy the prompt file

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    # Fail fast on unusable arguments, before distributed setup and the
    # (expensive) model loading.
    if opt.prompt is None:
        parser.error("--prompt is required (a sentence OR a txt file)")
    if opt.inject_lora and opt.lora_trigger_word == '':
        parser.error("--lora_trigger_word must be set when --inject_lora is used")

    os.makedirs(opt.save_dir, exist_ok=True)

    # set device
    if opt.ddp:
        setup_dist(opt.local_rank)
        # each process only generates its share of the requested samples
        opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size())
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{opt.gpu_id}"

    # set random seed (offset by rank so DDP processes do not share a seed)
    seed = None
    if opt.seed is not None:
        seed = opt.local_rank + opt.seed if opt.ddp else opt.seed
        seed_everything(seed)

    # dump args
    fpath = os.path.join(opt.save_dir, "sampling_args.yaml")
    with open(fpath, 'w') as f:
        yaml.dump(vars(opt), f, default_flow_style=False)

    # load & merge config
    config = OmegaConf.load(opt.config_path)
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(config, cli)
    print("config: \n", config)

    # get model & sampler
    model, _, _ = load_model(config, opt.ckpt_path,
                             inject_lora=opt.inject_lora,
                             lora_scale=opt.lora_scale,
                             lora_path=opt.lora_path
                             )
    ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None

    # prepare prompt
    if opt.prompt.endswith(".txt"):
        opt.prompt_file = opt.prompt
        opt.prompt = None
        with open(opt.prompt_file, 'r') as f:
            stripped_lines = (line.strip() for line in f)
            prompts = [text for text in stripped_lines if text]
        # Keep a copy of the prompt file next to the results. This replaces
        # an unquoted `cp` shell command; like that command, a failed copy
        # must not abort sampling.
        try:
            shutil.copy(opt.prompt_file, opt.save_dir)
        except OSError as err:
            print(f"Warning: could not copy {opt.prompt_file} to {opt.save_dir}: {err}")
    else:
        opt.prompt_file = None
        prompts = [opt.prompt]

    if opt.inject_lora:
        prompts = [p + opt.lora_trigger_word for p in prompts]

    # only one process writes results when running under DDP
    is_main_process = (not opt.ddp) or dist.get_rank() == 0

    # go
    start = time.time()
    for prompt in prompts:
        # sample
        samples = sample_text2video(model, prompt, opt.n_samples, opt.batch_size,
                                    sample_type=opt.sample_type, sampler=ddim_sampler,
                                    ddim_steps=opt.ddim_steps, eta=opt.eta,
                                    cfg_scale=opt.cfg_scale,
                                    decode_frame_bs=opt.decode_frame_bs,
                                    ddp=opt.ddp, show_denoising_progress=opt.show_denoising_progress,
                                    )
        # save
        if is_main_process:
            # derive a file name from the prompt text
            save_name = prompt.replace("/", "_slash_").replace(" ", "_")
            if seed is not None:
                save_name = save_name + f"_seed{seed:05d}"
            # Forward the saving options; previously they were parsed but
            # never reached save_results, so the flags had no effect.
            save_results(samples, opt.save_dir, save_name=save_name,
                         save_fps=opt.save_fps,
                         save_mp4=opt.save_mp4,
                         save_npz=opt.save_npz,
                         save_mp4_sheet=opt.save_mp4_sheet,
                         save_jpg=opt.save_jpg)
    print("Finish sampling!")
    print(f"Run time = {(time.time() - start):.2f} seconds")

    if opt.ddp:
        dist.destroy_process_group()
| # ------------------------------------------------------------------------------------------ | |
# Script entry point: run the sampling pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()