import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import os
import sys
try:
    import utils
    from diffusion import create_diffusion
except:
    sys.path.append(os.path.split(sys.path[0])[0])
    import utils
    from diffusion import create_diffusion
import argparse
import torchvision
from PIL import Image
from einops import rearrange
from models import get_models
from diffusers.models import AutoencoderKL
from models.clip import TextEmbedder
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from utils import mask_generation_before
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from vlogger.videofusion import fusion
from vlogger.videocaption import captioning
from vlogger.videoaudio import make_audio, merge_video_audio, concatenate_videos
from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
from vlogger.planning_utils.gpt4_utils import (readscript,
                                               readtimescript,
                                               readprotagonistscript,
                                               readreferencescript,
                                               readzhscript)


def auto_inpainting(args,
                    video_input,
                    masked_video,
                    mask,
                    prompt,
                    image,
                    vae,
                    text_encoder,
                    image_encoder,
                    diffusion,
                    model,
                    device,
                    ):
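    """Generate one 16-frame clip by masked video inpainting.

    Encodes the masked conditioning video into VAE latents, duplicates the
    text and optional reference-image embeddings for classifier-free guidance,
    runs DDIM sampling, and decodes the sampled latents back to pixel frames.
    """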
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
        image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
        model = ip_scale_set(model, args.ref_cfg_scale)
        if args.use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)

    b, f, c, h, w = video_input.shape
    latent_h = video_input.shape[-2] // 8
    latent_w = video_input.shape[-1] // 8
    if args.use_fp16:
        z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device)  # b,c,f,h,w
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, 16, latent_h, latent_w, device=device)  # b,c,f,h,w

    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    masked_video = torch.cat([masked_video] * 2)
    mask = torch.cat([mask] * 2)
    z = torch.cat([z] * 2)
    prompt_all = [prompt] + [args.negative_prompt]
    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        cfg_scale=args.cfg_scale,
                        use_fp16=args.use_fp16,
                        ip_hidden_states=image_prompt_embeds)

    # Sample images:
    samples = diffusion.ddim_sample_loop(model.forward_with_cfg,
                                         z.shape,
                                         z,
                                         clip_denoised=False,
                                         model_kwargs=model_kwargs,
                                         progress=True,
                                         device=device,
                                         mask=mask,
                                         x_start=masked_video,
                                         use_concat=True,
                                         )
    samples, _ = samples.chunk(2, dim=0)  # [1, 4, 16, 32, 32]
    if args.use_fp16:
        samples = samples.to(dtype=torch.float16)
    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()  # [16, 4, 32, 32]
    video_clip = vae.decode(video_clip / 0.18215).sample  # [16, 3, 256, 256]
    return video_clip
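

# Full vlog pipeline: load the protagonist/video/reference/zh/time scripts,
# generate each scene clip by clip (optionally guided by a reference image),
# then fuse the clips, add captions, synthesize audio, and merge everything
# into the final video.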
def main(args):
    # Setup PyTorch:
    if args.seed:
        torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_everything(args.seed)

    model = get_models(args).to(device)
    model = tca_transform_model(model).to(device)
    model = ip_transform_model(model).to(device)
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            model.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    if args.use_compile:
        model = torch.compile(model)

    # Load EMA weights, keeping only keys that exist in the current model.
    ckpt_path = args.ckpt
    state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
    model_dict = model.state_dict()
    pretrained_dict = {}
    for k, v in state_dict.items():
        if k in model_dict:
            pretrained_dict[k] = v
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()  # important!

    diffusion = create_diffusion(str(args.num_sampling_steps))
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
    text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
    if args.use_fp16:
        print('Warning: using half precision for inference!')
        vae.to(dtype=torch.float16)
        model.to(dtype=torch.float16)
        text_encoder.to(dtype=torch.float16)
    print("model ready!\n", flush=True)

    # load protagonist script
    character_places = readprotagonistscript(args.protagonist_file_path)
    print("protagonists ready!", flush=True)

    # load script
    video_list = readscript(args.script_file_path)
    print("video script ready!", flush=True)

    # load reference script
    reference_lists = readreferencescript(video_list, character_places, args.reference_file_path)
    print("reference script ready!", flush=True)

    # load zh script
    zh_video_list = readzhscript(args.zh_script_file_path)
    print("zh script ready!", flush=True)

    # load time script
    key_list = []
    for key, value in character_places.items():
        key_list.append(key)
    time_list = readtimescript(args.time_file_path)
    print("time script ready!", flush=True)
    # generation begin
    sample_list = []
    for i, text_prompt in enumerate(video_list):
        sample_list.append([])
        for time in range(time_list[i]):
            if time == 0:
                print('Generating the ({}) prompt'.format(text_prompt), flush=True)
                if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
                    pil_image = None
                else:
                    pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
                    pil_image = pil_image.resize((256, 256))
                video_input = torch.zeros([1, 16, 3, args.image_size[0], args.image_size[1]]).to(device)
                mask = mask_generation_before("first0", video_input.shape, video_input.dtype, device)  # b,f,c,h,w
                masked_video = video_input * (mask == 0)
                samples = auto_inpainting(args,
                                          video_input,
                                          masked_video,
                                          mask,
                                          text_prompt,
                                          pil_image,
                                          vae,
                                          text_encoder,
                                          image_encoder,
                                          diffusion,
                                          model,
                                          device,
                                          )
                sample_list[i].append(samples)
            else:
                if sum(video.shape[0] for video in sample_list[i]) / args.fps >= time_list[i]:
                    break
                print('Generating the ({}) prompt'.format(text_prompt), flush=True)
                if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
                    pil_image = None
                else:
                    pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
                    pil_image = pil_image.resize((256, 256))
                pre_video = sample_list[i][-1][-args.researve_frame:]
                f, c, h, w = pre_video.shape
                lat_video = torch.zeros(args.num_frames - args.researve_frame, c, h, w).to(device)
                video_input = torch.concat([pre_video, lat_video], dim=0)
                video_input = video_input.to(device).unsqueeze(0)
                mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
                masked_video = video_input * (mask == 0)
                video_clip = auto_inpainting(args,
                                             video_input,
                                             masked_video,
                                             mask,
                                             text_prompt,
                                             pil_image,
                                             vae,
                                             text_encoder,
                                             image_encoder,
                                             diffusion,
                                             model,
                                             device,
                                             )
                sample_list[i].append(video_clip[args.researve_frame:])
                print(video_clip[args.researve_frame:].shape)

        # transition
        if args.video_transition and i != 0:
            video_1 = sample_list[i - 1][-1][-1:]
            video_2 = sample_list[i][0][:1]
            f, c, h, w = video_1.shape
            video_middle = torch.zeros(args.num_frames - 2, c, h, w).to(device)
            video_input = torch.concat([video_1, video_middle, video_2], dim=0)
            video_input = video_input.to(device).unsqueeze(0)
            mask = mask_generation_before("onelast1", video_input.shape, video_input.dtype, device)
            masked_video = video_input * (mask == 0)
            video_clip = auto_inpainting(args,
                                         video_input,
                                         masked_video,
                                         mask,
                                         "smooth transition, slow motion, slow changing.",
                                         pil_image,
                                         vae,
                                         text_encoder,
                                         image_encoder,
                                         diffusion,
                                         model,
                                         device,
                                         )
            sample_list[i].insert(0, video_clip[1:-1])

        # save videos
        samples = torch.concat(sample_list[i], dim=0)
        samples = samples[0: time_list[i] * args.fps]
        if not os.path.exists(args.save_origin_video_path):
            os.makedirs(args.save_origin_video_path)
        video_ = ((samples * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
        torchvision.io.write_video(args.save_origin_video_path + "/" + f"{i}" + '.mp4', video_, fps=args.fps)

    # post processing
    fusion(args.save_origin_video_path)
    captioning(args.script_file_path, args.zh_script_file_path, args.save_origin_video_path, args.save_caption_video_path)
    fusion(args.save_caption_video_path)
    make_audio(args.script_file_path, args.save_audio_path)
    merge_video_audio(args.save_caption_video_path, args.save_audio_path, args.save_audio_caption_video_path)
    concatenate_videos(args.save_audio_caption_video_path)
    print('final video save path {}'.format(args.save_audio_caption_video_path))
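

# The script is driven by a YAML config (configs/vlog_read_script_sample.yaml
# by default). Besides save_path, sample_num, and seed used below, main() reads
# fields such as ckpt, pretrained_model_path, image_encoder_path, use_fp16,
# num_sampling_steps, cfg_scale, ref_cfg_scale, negative_prompt, image_size,
# num_frames, researve_frame, mask_type, fps, video_transition, and the
# script/reference/time file paths.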
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/vlog_read_script_sample.yaml")
    args = parser.parse_args()
    omega_conf = OmegaConf.load(args.config)
    save_path = omega_conf.save_path
    save_origin_video_path = os.path.join(save_path, "origin_video")
    save_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "caption_video")
    save_audio_path = os.path.join(save_path.rsplit('/', 1)[0], "audio")
    save_audio_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "audio_caption_video")
    if omega_conf.sample_num is not None:
        for i in range(omega_conf.sample_num):
            omega_conf.save_origin_video_path = save_origin_video_path + f'-{i}'
            omega_conf.save_caption_video_path = save_caption_video_path + f'-{i}'
            omega_conf.save_audio_path = save_audio_path + f'-{i}'
            omega_conf.save_audio_caption_video_path = save_audio_caption_video_path + f'-{i}'
            omega_conf.seed += i
            main(omega_conf)
    else:
        omega_conf.save_origin_video_path = save_origin_video_path
        omega_conf.save_caption_video_path = save_caption_video_path
        omega_conf.save_audio_path = save_audio_path
        omega_conf.save_audio_caption_video_path = save_audio_caption_video_path
        main(omega_conf)