import gradio as gr
import torch
import torchvision
from diffusers import I2VGenXLPipeline, DiffusionPipeline
from torchvision.transforms.functional import to_tensor
from PIL import Image
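
# Model loading and compilation sit under gr.NO_RELOAD so they run once and
# are skipped when the Gradio CLI hot-reloads this file during development.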
if gr.NO_RELOAD:
    n_steps = 50
    high_noise_frac = 0.8  # fraction of denoising done by the base model before hand-off to the refiner
    negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
    generator = torch.manual_seed(8888)  # fixed seed for reproducible videos
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    base = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=base.text_encoder_2,
        vae=base.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    pipeline = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
    )
base.to("cuda")
refiner.to("cuda")
pipeline.to("cuda")
    base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
    refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
    pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)

def generate(prompt: str, progress=gr.Progress()):
    progress((0, 100), desc="Starting...")

    def step_progress(offset: int, desc: str):
        # diffusers calls callback_on_step_end(pipe, step, timestep, kwargs)
        # and expects the kwargs dict back, so update the progress bar and
        # return the dict unchanged.
        def callback(pipe, step, timestep, callback_kwargs):
            progress((step + offset, 100), desc=desc)
            return callback_kwargs
        return callback

    # Stage 1: SDXL base denoises the first 80% of the steps and hands the
    # latents off to the refiner.
    image = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        callback_on_step_end=step_progress(0, "Generating first frame..."),
    ).images[0]
    # Stage 2: the refiner finishes the remaining 20% of the denoising.
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=image,
        callback_on_step_end=step_progress(40, "Refining first frame..."),
    ).images[0]
    image = to_tensor(image)
    # Stage 3: I2VGen-XL animates the refined first frame into video frames.
    frames: list[Image.Image] = pipeline(
        prompt=prompt,
        image=image,
        num_inference_steps=n_steps,
        negative_prompt=negative_prompt,
        guidance_scale=9.0,
        generator=generator,
        decode_chunk_size=10,
        callback_on_step_end=step_progress(50, "Generating video..."),
    ).frames[0]
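    # Convert the PIL frames to a (T, H, W, C) uint8 tensor for write_video.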
    frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
    frames = torch.stack(frames)
    torchvision.io.write_video("video.mp4", frames, fps=8)
    return "video.mp4"

app = gr.Interface(
    fn=generate,
    inputs=["text"],
    outputs=gr.Video(),
)

if __name__ == "__main__":
    app.launch()