File size: 2,811 Bytes
2c1f0c3
 
 
95cc45b
 
35920a6
2c1f0c3
95cc45b
448a859
95cc45b
 
 
0593b2c
 
c5c043a
95cc45b
 
 
 
 
 
448a859
 
 
 
 
 
 
 
95cc45b
448a859
 
 
 
 
 
 
 
95cc45b
0593b2c
 
95cc45b
 
 
448a859
 
7acc91c
448a859
 
 
 
 
 
7acc91c
95cc45b
 
 
35920a6
 
 
 
 
 
0593b2c
7acc91c
35920a6
95cc45b
 
448a859
254aacb
2c1f0c3
35920a6
2c1f0c3
95cc45b
2c1f0c3
35920a6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import torch
import torchvision
from diffusers import I2VGenXLPipeline, DiffusionPipeline
from torchvision.transforms.functional import to_tensor
from PIL import Image

if gr.NO_RELOAD:
    # Everything in this block is expensive (model downloads, weight loading,
    # compilation); the gr.NO_RELOAD guard keeps Gradio's reload mode from
    # re-executing it on every source change.

    # Sampling configuration shared by the pipelines used in generate().
    n_steps = 50
    # Fraction of the denoising schedule handled by the SDXL base model; the
    # refiner picks up from this point.
    high_noise_frac = 0.8
    negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
    # Fixed seed so the video pipeline's noise is reproducible.
    generator = torch.manual_seed(8888)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Half precision is only usable on the GPU; diffusers warns that float16
    # pipelines cannot run on CPU, so fall back to float32 there. The fp16
    # weight files are still downloaded and are cast on load.
    on_cuda = device.type == "cuda"
    dtype = torch.float16 if on_cuda else torch.float32

    base = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=dtype,
        variant="fp16",
        use_safetensors=True,
    )
    # The refiner reuses the base model's second text encoder and VAE instead
    # of loading its own copies.
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=base.text_encoder_2,
        vae=base.vae,
        torch_dtype=dtype,
        use_safetensors=True,
        variant="fp16",
    )
    pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=dtype, variant="fp16")

    # Move to the detected device rather than a hard-coded "cuda", so the
    # script does not crash on machines without a GPU.
    base.to(device)
    refiner.to(device)
    pipeline.to(device)

    if on_cuda:
        # "reduce-overhead" relies on CUDA graphs, so compiling is only
        # worthwhile (and only attempted) when running on the GPU.
        base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
        refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
        pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)

def _step_reporter(progress, offset: int, desc: str):
    """Build a diffusers ``callback_on_step_end`` that forwards progress to Gradio.

    Args:
        progress: Gradio progress tracker to update.
        offset: Value added to the pipeline's step index before reporting.
        desc: Status text shown next to the progress bar.

    Returns:
        A callback with the ``(pipe, step, timestep, callback_kwargs)``
        signature that reports ``(step + offset, 100)`` and returns
        ``callback_kwargs`` unchanged.
    """

    def _on_step_end(pipe, step, timestep, callback_kwargs):
        progress((step + offset, 100), desc=desc)
        # The pipelines read tensors back out of the callback's return value,
        # so the kwargs dict must be returned; returning None makes the
        # pipeline fail on the first step.
        return callback_kwargs

    return _on_step_end


def generate(prompt: str, progress=gr.Progress()) -> str:
    """Generate a short video from a text prompt and return its file path.

    The SDXL base and refiner pipelines render a first frame, which the
    I2VGen-XL pipeline then animates. The resulting frames are written to
    ``video.mp4`` at 8 fps.

    Args:
        prompt: Text prompt used for both the first frame and the video.
        progress: Gradio progress tracker, updated as each stage advances.

    Returns:
        The path of the written video file (``"video.mp4"``).
    """
    import inspect

    # Progress is reported out of 100. The base pass reports its raw step
    # index, the refiner is offset by 40 and the video pass by 50; these
    # offsets assume the n_steps=50 / high_noise_frac=0.8 split.
    progress((0, 100), desc="Starting..")

    # Keep the batch dimension on the latents: handing the whole `.images`
    # batch to the refiner is the documented base -> refiner hand-off.
    latents = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        callback_on_step_end=_step_reporter(progress, 0, "Generating first frame..."),
    ).images
    first_frame = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=latents,
        callback_on_step_end=_step_reporter(progress, 40, "Refining first frame..."),
    ).images[0]
    first_frame = to_tensor(first_frame)

    # Not every diffusers pipeline accepts a step-end callback; only pass it
    # when the video pipeline's signature has the parameter, so the call does
    # not fail with an unexpected-keyword TypeError.
    video_kwargs = {}
    if "callback_on_step_end" in inspect.signature(pipeline.__call__).parameters:
        video_kwargs["callback_on_step_end"] = _step_reporter(progress, 50, "Generating video...")
    else:
        progress((50, 100), desc="Generating video...")

    frames: list[Image.Image] = pipeline(
        prompt=prompt,
        image=first_frame,
        num_inference_steps=n_steps,
        negative_prompt=negative_prompt,
        guidance_scale=9.0,
        generator=generator,
        decode_chunk_size=10,
        **video_kwargs,
    ).frames[0]

    # write_video expects a uint8 tensor laid out as (T, H, W, C).
    # pil_to_tensor keeps the original uint8 values, avoiding the
    # divide-by-255 / multiply-by-255 float round trip whose truncation in
    # .byte() can shift pixel values by one.
    pil_to_tensor = torchvision.transforms.functional.pil_to_tensor
    video = torch.stack(
        [pil_to_tensor(frame.convert("RGB")).permute(1, 2, 0) for frame in frames]
    )
    torchvision.io.write_video("video.mp4", video, fps=8)
    return "video.mp4"

# Gradio UI: one text prompt in, the rendered video file out.
app = gr.Interface(generate, inputs=["text"], outputs=gr.Video())

if __name__ == "__main__":
    # Start the web server only when this file is run as a script.
    app.launch()