import argparse, os, sys, glob, yaml, math, random
import datetime, time
import numpy as np
from omegaconf import OmegaConf
from collections import OrderedDict
from tqdm import trange, tqdm
from einops import rearrange, repeat
from functools import partial
import torch
from pytorch_lightning import seed_everything

from .funcs import load_model_checkpoint, load_image_batch, get_filelist, save_videos, batch_ddim_sampling
from .utils import instantiate_from_config

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything")
    parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}")
    parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--savefps", type=str, default=10, help="video fps to generate")
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt", )
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM", )
    parser.add_argument("--ddim_eta", type=float, default=1.0,
                        help="eta for ddim sampling (0.0 yields deterministic sampling)", )
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference")
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
    parser.add_argument("--frames", type=int, default=-1, help="frames num to inference")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0,
                        help="prompt classifier-free guidance")
    parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None,
                        help="temporal consistency guidance")
    ## for conditional i2v only
    # parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input")
    return parser


class VideoCrafterPipeline():
    def __init__(self, arg_list, device, rank: int = 0, gpu_num: int = 1):
        """
        Initialize the pipeline of videocrafter.
        It is always on one GPU.
        Args:
            arg_list: The parameters needed for the model.
            device:
            rank:
            gpu_num:
        """
        parser = get_parser()
        self.args = parser.parse_args(args=arg_list)

        self.gpu_no, self.gpu_num = rank, gpu_num
        _dict = {
            'model': {
                'target': 'lvdm.models.ddpm3d.LatentDiffusion',
                'params': {
                    'linear_start': 0.00085, 'linear_end': 0.012,
                    'num_timesteps_cond': 1, 'timesteps': 1000,
                    'first_stage_key': 'video', 'cond_stage_key': 'caption',
                    'cond_stage_trainable': False, 'conditioning_key': 'crossattn',
                    'image_size': [40, 64], 'channels': 4,
                    'scale_by_std': False, 'scale_factor': 0.18215,
                    'use_ema': False, 'uncond_type': 'empty_seq',
                    'use_scale': True, 'scale_b': 0.7,
                    'unet_config': {
                        'target': 'lvdm.modules.networks.openaimodel3d.UNetModel',
                        'params': {
                            'in_channels': 4, 'out_channels': 4, 'model_channels': 320,
                            'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2,
                            'channel_mult': [1, 2, 4, 4], 'num_head_channels': 64,
                            'transformer_depth': 1, 'context_dim': 1024,
                            'use_linear': True, 'use_checkpoint': True,
                            'temporal_conv': True, 'temporal_attention': True,
                            'temporal_selfatt_only': True, 'use_relative_position': False,
                            'use_causal_attention': False, 'temporal_length': 16,
                            'addition_attention': True, 'fps_cond': True,
                        },
                    },
                    'first_stage_config': {
                        'target': 'lvdm.models.autoencoder.AutoencoderKL',
                        'params': {
                            'embed_dim': 4, 'monitor': 'val/rec_loss',
                            'ddconfig': {
                                'double_z': True, 'z_channels': 4, 'resolution': 512,
                                'in_channels': 3, 'out_ch': 3, 'ch': 128,
                                'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
                                'attn_resolutions': [], 'dropout': 0.0,
                            },
                            'lossconfig': {'target': 'torch.nn.Identity'},
                        },
                    },
                    'cond_stage_config': {
                        'target': 'lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder',
                        'params': {'freeze': True, 'layer': 'penultimate'},
                    },
                },
            },
        }

        config = OmegaConf.create(_dict)
        #config = OmegaConf.load(self.args.config)

        # data_config = config.pop("data", OmegaConf.create())
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.cuda(self.gpu_no)
        print("About to load model")
        assert os.path.exists(self.args.ckpt_path), f"Error: checkpoint [{self.args.ckpt_path}] Not Found!"
        self.model = load_model_checkpoint(model, self.args.ckpt_path)
        self.model.eval()

    def run_inference(self, prompt, video_length, height, width, **kwargs):
        """
        https://github.com/AILab-CVC/VideoCrafter
        Generate video from the provided text prompt.
        Args:
            prompt: The provided text prompt.
            video_length: The length (num of frames) of the generated video.
            height: The height of the video frame.
            width: The width of the video frame.
            **kwargs:

        Returns:
            The generated video represented as tensor with shape (1, 1, channels, height, width, num of frames)

        """
        ## step 1: model config
        ## -----------------------------------------------------------------
        ## sample shape
        assert (height % 16 == 0) and (width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
        ## latent noise shape
        h, w = height // 8, width // 8
        frames = video_length
        channels = self.model.channels

        ## step 2: load data
        ## -----------------------------------------------------------------
        prompt_list = [prompt]
        num_samples = len(prompt_list)
        # filename_list = [f"{id + 1:04d}" for id in range(num_samples)]

        gpu_num = self.gpu_num
        gpu_no = self.gpu_no
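        # Shard the prompt list across ranks: each rank takes a contiguous slice,
        # and rank 0 additionally picks up any remainder.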
        samples_split = num_samples // gpu_num
        residual_tail = num_samples % gpu_num
        print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.')
        indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
        if gpu_no == 0 and residual_tail != 0:
            indices = indices + list(range(num_samples - residual_tail, num_samples))
        prompt_list_rank = [prompt_list[i] for i in indices]

        # # conditional input
        # if self.args.mode == "i2v":
        #     ## each video or frames dir per prompt
        #     cond_inputs = get_filelist(self.args.cond_input, ext='[mpj][pn][4gj]')  # '[mpj][pn][4gj]'
        #     assert len(
        #         cond_inputs) == num_samples, f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!"
        #     filename_list = [f"{os.path.split(cond_inputs[id])[-1][:-4]}" for id in range(num_samples)]
        #     cond_inputs_rank = [cond_inputs[i] for i in indices]

        # filename_list_rank = [filename_list[i] for i in indices]

        ## step 3: run over samples
        ## -----------------------------------------------------------------
        # start = time.time()
        n_rounds = math.ceil(len(prompt_list_rank) / self.args.bs)
        for idx in range(0, n_rounds):
            print(f'[rank:{gpu_no}] batch-{idx + 1} ({self.args.bs})x{self.args.n_samples} ...')
            idx_s = idx * self.args.bs
            idx_e = min(idx_s + self.args.bs, len(prompt_list_rank))
            batch_size = idx_e - idx_s
            # filenames = filename_list_rank[idx_s:idx_e]
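            # Latent-space noise shape: [batch, latent channels, frames, height//8, width//8].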
            noise_shape = [batch_size, channels, frames, h, w]
            fps = torch.tensor([self.args.fps] * batch_size).to(self.model.device).long()

            prompts = prompt_list_rank[idx_s:idx_e]
            if isinstance(prompts, str):
                prompts = [prompts]
            # prompts = batch_size * [""]
            text_emb = self.model.get_learned_conditioning(prompts)

            if self.args.mode == 'base':
                cond = {"c_crossattn": [text_emb], "fps": fps}
            # elif self.args.mode == 'i2v':
            #     # cond_images = torch.zeros(noise_shape[0],3,224,224).to(model.device)
            #     cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (self.args.height, self.args.width))
            #     cond_images = cond_images.to(self.model.device)
            #     img_emb = self.model.get_image_embeds(cond_images)
            #     imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            #     cond = {"c_crossattn": [imtext_cond], "fps": fps}
            else:
                raise NotImplementedError(f"Unsupported inference mode: {self.args.mode}")

            ## inference
            batch_samples = batch_ddim_sampling(self.model, cond, noise_shape, self.args.n_samples,
                                                self.args.ddim_steps,
                                                self.args.ddim_eta,
                                                self.args.unconditional_guidance_scale, **kwargs)
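            # NOTE: returns after the first batch; with a single prompt there is only one round.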
            return batch_samples
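

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the checkpoint path below is hypothetical,
# and this module is run in its package context (e.g. `python -m <package>.<module>`)
# so that the relative imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    arg_list = [
        "--ckpt_path", "checkpoints/base_512_v2/model.ckpt",  # hypothetical path
        "--ddim_steps", "50",
        "--unconditional_guidance_scale", "12.0",
        "--height", "320", "--width", "512",
    ]
    pipeline = VideoCrafterPipeline(arg_list, device="cuda", rank=0, gpu_num=1)
    video = pipeline.run_inference(
        prompt="a corgi running on the beach",
        video_length=16,  # matches temporal_length in the inlined config
        height=320, width=512)
    print(video.shape)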