# -*- coding: utf-8 -*-
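# DynamiCrafter image-to-video Gradio demo: a (Korean or English) text prompt
# is translated to English if needed, rendered to a 1024x576 still image with
# RealVisXL (SDXL), and then animated into a short video with DynamiCrafter 1024.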
import gradio as gr
import os
import sys
import random
import time
import uuid
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from utils.utils import instantiate_from_config
sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling,
    load_model_checkpoint,
    get_latent_z,
    save_videos
)
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from diffusers import StableDiffusionXLPipeline

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


def is_tensor(x):
    return torch.is_tensor(x)

# ๋ฒˆ์—ญ ๋ชจ๋ธ ๋กœ๋“œ (PyTorch ๋ฒ„์ „ ์‚ฌ์šฉ)
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device=0 if torch.cuda.is_available() else -1, framework="pt")
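# Illustrative call (exact output is model-dependent):
#   translator("์•ˆ๋…•ํ•˜์„ธ์š”")[0]['translation_text']  # -> English text such as "Hello."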

# ์ด๋ฏธ์ง€ ์ƒ์„ฑ ๋ชจ๋ธ ๋กœ๋“œ
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
    add_watermarker=False
).to(device)
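
# Note: float32 keeps full precision; on memory-constrained GPUs, passing
# torch_dtype=torch.float16 instead would roughly halve the pipeline's VRAM use.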


# Keras 3 expects 'torch' (not 'pytorch') as the backend name; the variable is
# unused by this script but kept in case a downstream import reads it.
os.environ['KERAS_BACKEND'] = 'torch'

def download_model():
    REPO_ID = 'Doubiiu/DynamiCrafter_1024'
    filename_list = ['model.ckpt']
    if not os.path.exists('./checkpoints/dynamicrafter_1024_v1/'):
        os.makedirs('./checkpoints/dynamicrafter_1024_v1/')
    for filename in filename_list:
        local_file = os.path.join('./checkpoints/dynamicrafter_1024_v1/', filename)
        if not os.path.exists(local_file):
            hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_1024_v1/', force_download=True)

download_model()
ckpt_path = 'checkpoints/dynamicrafter_1024_v1/model.ckpt'
config_file = 'configs/inference_1024_v1.0.yaml'
config = OmegaConf.load(config_file)
model_config = config.pop("model", OmegaConf.create())
model_config['params']['unet_config']['params']['use_checkpoint'] = True
model = instantiate_from_config(model_config)
assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, ckpt_path)
model.eval()

# ๋ชจ๋ธ์„ DataParallel๋กœ ๊ฐ์‹ธ์„œ ์—ฌ๋Ÿฌ GPU์—์„œ ์‹คํ–‰ ๊ฐ€๋Šฅํ•˜๊ฒŒ ์„ค์ •
#model = torch.nn.DataParallel(model)
model = model.cuda()



def generate_image(prompt: str):
    # Detect Korean input (U+AC00..U+D7A3, the Hangul Syllables block) and translate it
    if any('\uac00' <= char <= '\ud7a3' for char in prompt):
        translated = translator(prompt, max_length=512)
        prompt = translated[0]['translation_text']
    
    # Wrap the prompt in a high-resolution, photorealistic style template
    prompt = f"hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic"
    
    # ๊ณ ์ •๋œ ์„ค์ •๊ฐ’
    negative_prompt = "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly, (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, disgusting, amputation"
    width = 1024
    height = 576
    guidance_scale = 6
    num_inference_steps = 100
    seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator().manual_seed(seed)
    
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    
    unique_name = str(uuid.uuid4()) + ".png"
    image.save(unique_name)
    return unique_name
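
# Usage sketch: generate_image("a lighthouse at dawn") renders a 1024x576 image
# and returns its UUID-named PNG path in the working directory.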

# @spaces.GPU(duration=300, gpu_type="l40s")
def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
    try:
        image_path = generate_image(prompt)
        image = torchvision.io.read_image(image_path).float() / 255.0
        
        if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
            translated = translator(prompt, max_length=512)
            prompt = translated[0]['translation_text']
        
        resolution = (576, 1024)
        save_fps = 8
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(resolution, antialias=True),
        ])
        
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        # Cap sampling steps to bound runtime
        if steps > 60:
            steps = 60

        batch_size = 1
        channels = model.model.out_channels
        # The VAE downsamples by 8x, so the latent noise shape is 1/8 of the pixel resolution
        h, w = resolution[0] // 8, resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        with torch.no_grad(), torch.cuda.amp.autocast():
            text_emb = model.get_learned_conditioning([prompt])  
            
            img_tensor = image.to(torch.cuda.current_device())
            img_tensor = (img_tensor - 0.5) * 2
            image_tensor_resized = transform(img_tensor)
            videos = image_tensor_resized.unsqueeze(0)
            
            z = get_latent_z(model, videos.unsqueeze(2)) 
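            # Tile the single-frame latent along the time axis so each of the
            # `frames` timesteps is conditioned on the same input image.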
            img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
        
            cond_images = model.embedder(img_tensor.unsqueeze(0)) 
            img_emb = model.image_proj_model(cond_images) 
        
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
        
            fs = torch.tensor([fs], dtype=torch.long, device=torch.cuda.current_device())
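            # Conditioning dict: c_crossattn carries the fused text+image
            # embeddings, fs is the frame-stride (motion strength) control, and
            # c_concat is the repeated image latent concatenated at the UNet input.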
            cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
            
            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
        
            video_path = './output.mp4'
            save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
        
        return video_path

    except Exception as e:
        print(f"Error occurred: {e}")
        return None
    finally:
        torch.cuda.empty_cache()



i2v_examples = [
    ['์šฐ์ฃผ์ธ ๋ณต์žฅ์œผ๋กœ ๊ธฐํƒ€๋ฅผ ์น˜๋Š” ๋‚จ์ž', 30, 7.5, 1.0, 6, 123, 64],
    ['time-lapse of a blooming flower with leaves and a stem', 30, 7.5, 1.0, 10, 123, 64],
]

css = """#output_vid {max-width: 1024px; max-height: 576px}"""

with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
    
    with gr.Tab(label='ImageAnimation_576x1024'):
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        i2v_input_text = gr.Textbox(label='Prompts (Korean input supported)')
                    with gr.Row():
                        i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
                        i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
                        i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
                    with gr.Row():
                        i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
                        i2v_motion = gr.Slider(minimum=5, maximum=20, step=1, elem_id="i2v_motion", label="FPS", value=10)
                        i2v_frames = gr.Slider(minimum=16, maximum=128, step=16, elem_id="i2v_frames", label="Number of frames", value=64)
                    i2v_end_btn = gr.Button("Generate")
                with gr.Row():
                    i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True, show_share_button=True)

            gr.Examples(examples=i2v_examples,
                        inputs=[i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_frames],
                        outputs=[i2v_output_video],
                        fn=infer,
                        cache_examples=False
            )
        i2v_end_btn.click(inputs=[i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_frames],
                          outputs=[i2v_output_video],
                          fn=infer
        )

dynamicrafter_iface.launch()