Merge branch 'main' of https://huggingface.co/spaces/ychenhq/VideoCrafterXen

Files changed:
- cog.yaml +25 -0
- final-year-project-443dd-df6f48af0796.json +13 -0
- predict.py +155 -0
- requirements.txt +24 -0
    	
cog.yaml ADDED

@@ -0,0 +1,25 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  gpu: true
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_version: "3.11"
+  python_packages:
+    - "torch==2.0.1"
+    - "opencv-python==4.8.1.78"
+    - "torchvision==0.15.2"
+    - "pytorch_lightning==2.1.0"
+    - "einops==0.7.0"
+    - "imageio==2.31.6"
+    - "omegaconf==2.3.0"
+    - "transformers==4.35.0"
+    - "moviepy==1.0.3"
+    - "av==10.0.0"
+    - "decord==0.6.0"
+    - "kornia==0.7.0"
+    - "open-clip-torch==2.12.0"
+    - "xformers==0.0.21"
+predict: "predict.py:Predictor"
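For context, the `predict: "predict.py:Predictor"` entry tells Cog to build this environment and route requests to the Predictor class added below. A minimal smoke test that drives the same class outside the container (a sketch, assuming the VideoCrafter checkpoints and configs referenced in predict.py have already been downloaded into the repo):

    # Sketch: call the Cog predictor directly instead of via `cog predict`.
    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # loads both the 1024 t2v and 512 i2v checkpoints onto the GPU
    out_path = predictor.predict(
        task="text2video",
        prompt="A tiger walks in the forest, photorealistic, 4k, high definition.",
        image=None,
        ddim_steps=50,
        unconditional_guidance_scale=12.0,
        seed=42,
        save_fps=10,
    )
    print(out_path)  # /tmp/output.mp4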
    	
final-year-project-443dd-df6f48af0796.json ADDED

@@ -0,0 +1,13 @@
+{
+  "type": "service_account",
+  "project_id": "final-year-project-443dd",
+  "private_key_id": "df6f48af0796ab27ae03fb99d08afca2ac2b00ef",
+  "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCmOla5Gssdx196\n4OZyrsE1so4q3nc1fWjNs9PsQ/cm6lTHTENAMM4yHbr0no4b5jL2KgBFwAsIMAMI\nzmHJNc+r/3dnLcPOvnUH8PlkaZNpH/5eQueLz8is7QcqvtnImkg/v2wlXLXWKwWx\nlWyvW10UuYry5qsta3aclqxmhP1jem6QnQxKLiQUNdAPbqsbFyEA11QHzivsTAac\nGdDHF2V/yJ05dqRE+40EaFYbzTXHUBglC0SbgGL512KvpSC16qwFBbY9oy+jHQ55\n8uzVVw5OCSmMCI+UmOrMSe/sI67jHXgOK/GexrHNazh2XbZUSupPIIz1lsBXUl1D\n8L3UdiWVAgMBAAECggEAJwZnOcnaicE230hRkfcJESw8SEA2SG6K3lArnrpOGerF\nwIxc9YL/xbBJJgjbYB1pNXWi3r05WdC7xaN+PZjOipjNVYHfCHiaTST7x+EpZHLI\nayTV63L6r+5t0lFAG+Jst9qe7x6W6hLroUdtXrXaYnU089XHtkAWdqjBDMiIHIRO\nZM9fAnCK/0dShYa0oD1BrjrGCUDrYdJ9I3WJWU+LHBfTZfLXEWbKeE+6665bC7IY\nB9JqhMlbNJWqNwIrg/bB8lI1qIGBY7lEl32N4cQ/JXXpOtfZGx7EAlYiez+bbgnI\nbJN637gp95E8V4l1eSDoF4FdIiygVcghXavOz+AHQQKBgQDmD8NjgkZQ9iiD+1kM\nJUi5AY+xgwOPfR+/vQSM2XWe5Q2jKOR82327Hj3bgua9pWr5FlPRFOakHIohV6nx\nFHkU9LVFwA9tL2pbs+kditDwg8doJtU/wpUW9kYhJ1MAY6dyuRr53CT4XIscXlKX\nHlOK5NClSNY0wFdgIxrQ3vGR/QKBgQC4+Cb2/Chsuh2jt0mp5IESYk38f9E4/YA3\n/1m8aQIbEUfhT3Xihk/MyhOp5MisnACt4kBH2KnrFzB1FAXtAgJQMvP2hLZekTQs\nhYMD2MfsT+E1Fj/bquIh4rDmrAW2wal+HzFBcuqBo81xXrokZGood9TnDNwwow1f\nMus3AXNJeQKBgGaVqtNpWL9rNB+96TQQQAA24QMPX3wRGCIgP7IqmVcT3ePeLRw7\npzHTx1NlaEwyQaP2P8OgZUPScglyFJYqQd+FSntiq75NAUkIzS7eIlLNABLCFh7L\nPj2x7Q2Fgm5PAXCXd57oehfA9ErfCEbYP/pUE3FQLCvzhEKbBK8UanVlAoGBAIkk\nPEedmB9dMwKir/ROHsDRsD7JSgf2NK3QHumJ9ey5uFC+iIoGyX3uSfwKTBtmoz5J\nZR2f8AQFMoFr8iTS+4IY9TdPGKQvBr8H0qb0gO6eHz0sHPay0W0MVdsBqk7hcdi4\nKd375RFvsLAg6uR2qxsMFgelSlCpZA20hB9JbQAJAoGAEmCK/A7k4AJq0cWtad3y\n9wmUsvGFZUhqj1nYtZ2GchKWIcszM28G77AnT52vPNjSDfygQAVxQ7NSYIcwULiA\nMHL4pB8RQr6P4yXISh7dPG8dlrhefrm4KdVMZPOz0Cpry4KejYWKx/YMjqZxARDd\nZFRtycZMdS8kBvSHeyc4mH8=\n-----END PRIVATE KEY-----\n",
+  "client_email": "firebase-adminsdk-74lss@final-year-project-443dd.iam.gserviceaccount.com",
+  "client_id": "104174452867915111710",
+  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+  "token_uri": "https://oauth2.googleapis.com/token",
+  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+  "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-74lss%40final-year-project-443dd.iam.gserviceaccount.com",
+  "universe_domain": "googleapis.com"
+}
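Nothing else in this commit reads this service-account file, so its consumer is an assumption; a typical firebase-admin initialization for a key of this shape is sketched below. Note that a private key committed to a public Space is effectively disclosed and would need to be rotated.

    # Sketch: hypothetical consumer of the committed key. The commit itself
    # never loads this file, so this usage is an assumption.
    import firebase_admin
    from firebase_admin import credentials

    cred = credentials.Certificate("final-year-project-443dd-df6f48af0796.json")
    app = firebase_admin.initialize_app(cred)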
    	
predict.py ADDED

@@ -0,0 +1,155 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+
+import os
+import sys
+import argparse
+import random
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+import torch
+import torchvision
+from pytorch_lightning import seed_everything
+from cog import BasePredictor, Input, Path
+
+sys.path.insert(0, "scripts/evaluation")
+from funcs import (
+    batch_ddim_sampling,
+    load_model_checkpoint,
+    load_image_batch,
+    get_filelist,
+)
+from utils.utils import instantiate_from_config
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+
+        ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
+        config_base = "configs/inference_t2v_1024_v1.0.yaml"
+        ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
+        config_i2v = "configs/inference_i2v_512_v1.0.yaml"
+
+        config_base = OmegaConf.load(config_base)
+        model_config_base = config_base.pop("model", OmegaConf.create())
+        self.model_base = instantiate_from_config(model_config_base)
+        self.model_base = self.model_base.cuda()
+        self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
+        self.model_base.eval()
+
+        config_i2v = OmegaConf.load(config_i2v)
+        model_config_i2v = config_i2v.pop("model", OmegaConf.create())
+        self.model_i2v = instantiate_from_config(model_config_i2v)
+        self.model_i2v = self.model_i2v.cuda()
+        self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
+        self.model_i2v.eval()
+
+    def predict(
+        self,
+        task: str = Input(
+            description="Choose the task.",
+            choices=["text2video", "image2video"],
+            default="text2video",
+        ),
+        prompt: str = Input(
+            description="Prompt for video generation.",
+            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
+        ),
+        image: Path = Input(
+            description="Input image for image2video task.", default=None
+        ),
+        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
+        unconditional_guidance_scale: float = Input(
+            description="Classifier-free guidance scale.", default=12.0
+        ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed", default=None
+        ),
+        save_fps: int = Input(
+            description="Frame per second for the generated video.", default=10
+        ),
+    ) -> Path:
+
+        width = 1024 if task == "text2video" else 512
+        height = 576 if task == "text2video" else 320
+        model = self.model_base if task == "text2video" else self.model_i2v
+
+        if task == "image2video":
+            assert image is not None, "Please provide image for image2video generation."
+
+        if seed is None:
+            seed = int.from_bytes(os.urandom(2), "big")
+        print(f"Using seed: {seed}")
+        seed_everything(seed)
+
+        args = argparse.Namespace(
+            mode="base" if task == "text2video" else "i2v",
+            savefps=save_fps,
+            n_samples=1,
+            ddim_steps=ddim_steps,
+            ddim_eta=1.0,
+            bs=1,
+            height=height,
+            width=width,
+            frames=-1,
+            fps=28 if task == "text2video" else 8,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_guidance_scale_temporal=None,
+        )
+
+        ## latent noise shape
+        h, w = args.height // 8, args.width // 8
+        frames = model.temporal_length if args.frames < 0 else args.frames
+        channels = model.channels
+
+        batch_size = 1
+        noise_shape = [batch_size, channels, frames, h, w]
+        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
+        prompts = [prompt]
+        text_emb = model.get_learned_conditioning(prompts)
+
+        if args.mode == "base":
+            cond = {"c_crossattn": [text_emb], "fps": fps}
+        elif args.mode == "i2v":
+            cond_images = load_image_batch([str(image)], (args.height, args.width))
+            cond_images = cond_images.to(model.device)
+            img_emb = model.get_image_embeds(cond_images)
+            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
+            cond = {"c_crossattn": [imtext_cond], "fps": fps}
+        else:
+            raise NotImplementedError
+
+        ## inference
+        batch_samples = batch_ddim_sampling(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+        )
+
+        out_path = "/tmp/output.mp4"
+        vid_tensor = batch_samples[0]
+        video = vid_tensor.detach().cpu()
+        video = torch.clamp(video.float(), -1.0, 1.0)
+        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
+
+        frame_grids = [
+            torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
+            for framesheet in video
+        ]  # [3, 1*h, n*w]
+        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
+        grid = (grid + 1.0) / 2.0
+        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+        torchvision.io.write_video(
+            out_path,
+            grid,
+            fps=args.savefps,
+            video_codec="h264",
+            options={"crf": "10"},
+        )
+        return Path(out_path)
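The noise-shape arithmetic in predict() follows the usual 8x VAE downsampling, and when no seed is supplied, int.from_bytes(os.urandom(2), "big") draws one from 0-65535. A worked sketch of the shape computation (channels=4 and frames=16 are illustrative assumptions; the real values come from model.channels and model.temporal_length of the loaded config):

    # Sketch of the latent noise shape computed in predict().
    def noise_shape(width, height, channels=4, frames=16, batch_size=1):
        h, w = height // 8, width // 8  # the VAE downsamples each spatial dim by 8
        return [batch_size, channels, frames, h, w]

    assert noise_shape(1024, 576) == [1, 4, 16, 72, 128]  # text2video
    assert noise_shape(512, 320) == [1, 4, 16, 40, 64]    # image2video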
    	
requirements.txt ADDED

@@ -0,0 +1,24 @@
+decord==0.6.0
+einops==0.3.0
+imageio==2.9.0
+numpy==1.24.2
+omegaconf==2.1.1
+opencv_python>=4.1.2
+pandas==2.0.0
+Pillow==9.5.0
+pytorch_lightning==1.8.3
+PyYAML==6.0
+setuptools==65.6.3
+torch==2.0.0
+torchvision>=0.7.0
+tqdm==4.65.0
+transformers==4.25.1
+moviepy>=1.0.3
+av
+xformers
+gradio
+timm
+scikit-learn
+open_clip_torch==2.22.0
+kornia
+sk-video>=1.1.10
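Several pins here diverge from cog.yaml above: torch 2.0.0 vs 2.0.1, pytorch_lightning 1.8.3 vs 2.1.0, einops 0.3.0 vs 0.7.0, transformers 4.25.1 vs 4.35.0, and open_clip_torch 2.22.0 vs open-clip-torch 2.12.0, so the Space and the Cog container run different environments. A small sketch for checking which set the active environment matches:

    # Sketch: compare a few of the divergent pins against the active environment.
    from importlib.metadata import version, PackageNotFoundError

    pins = {"torch": "2.0.0", "pytorch-lightning": "1.8.3",
            "einops": "0.3.0", "transformers": "4.25.1"}
    for pkg, expected in pins.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = "missing"
        status = "ok" if installed == expected else "MISMATCH"
        print(f"{pkg}: pinned {expected}, installed {installed} ({status})")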