mrcuddle commited on
Commit
fe77a8e
·
verified ·
1 Parent(s): bda76bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -37
app.py CHANGED
@@ -2,51 +2,63 @@ import gradio as gr
2
  import torch
3
  from diffusers import I2VGenXLPipeline
4
  from diffusers.utils import export_to_gif, load_image
5
- import spaces
6
 
7
- # Initialize the pipeline
8
- pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
9
- pipeline.enable_model_cpu_offload()
10
 
11
- @spaces.GPU(duration=240)
12
- def generate_gif(image, prompt, negative_prompt, num_inference_steps, guidance_scale, seed):
13
- # Load the image
14
- image = load_image(image).convert("RGB")
15
 
 
16
  # Set the generator seed
17
- generator = torch.manual_seed(seed)
18
-
19
- # Generate the frames
20
- frames = pipeline(
21
- prompt=prompt,
22
- image=image,
23
- num_inference_steps=num_inference_steps,
24
- negative_prompt=negative_prompt,
25
- guidance_scale=guidance_scale,
26
- generator=generator
27
- ).frames[0]
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Export to GIF
30
- gif_path = "i2v.gif"
31
- export_to_gif(frames, gif_path)
 
32
 
33
  return gif_path
34
 
35
- # Create the Gradio interface
36
- iface = gr.Interface(
37
- fn=generate_gif,
38
- inputs=[
39
- gr.Image(type="filepath", label="Input Image"),
40
- gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
41
- gr.Textbox(lines=2, placeholder="Enter your negative prompt here...", label="Negative Prompt"),
42
- gr.Slider(1, 50, step=1, value=15, label="Number of Inference Steps"),
43
- gr.Slider(1, 10, step=0.1, value=9.0, label="Guidance Scale"),
44
- gr.Number(label="Seed", value=8888)
45
- ],
46
- outputs=gr.Video(label="Generated GIF"),
47
- title="I2VGen-XL GIF Generator",
48
- description="Generate a GIF from an image and a prompt using the I2VGen-XL model."
49
- )
 
 
50
 
51
  # Launch the interface
52
- iface.launch()
 
2
  import torch
3
  from diffusers import I2VGenXLPipeline
4
  from diffusers.utils import export_to_gif, load_image
5
+ import tempfile
6
 
7
+ # Check if CUDA is available and set the device
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
+ # Initialize the pipeline with CUDA support
11
+ pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
12
+ pipeline.to(device)
 
13
 
14
+ def generate_gif(prompt, image, negative_prompt, num_inference_steps, guidance_scale, seed):
15
  # Set the generator seed
16
+ generator = torch.Generator(device=device).manual_seed(seed)
17
+
18
+ # Check if an image is provided
19
+ if image is not None:
20
+ image = load_image(image).convert("RGB")
21
+ frames = pipeline(
22
+ prompt=prompt,
23
+ image=image,
24
+ num_inference_steps=num_inference_steps,
25
+ negative_prompt=negative_prompt,
26
+ guidance_scale=guidance_scale,
27
+ generator=generator
28
+ ).frames[0]
29
+ else:
30
+ frames = pipeline(
31
+ prompt=prompt,
32
+ num_inference_steps=num_inference_steps,
33
+ negative_prompt=negative_prompt,
34
+ guidance_scale=guidance_scale,
35
+ generator=generator
36
+ ).frames[0]
37
 
38
  # Export to GIF
39
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".gif") as tmp_gif:
40
+ gif_path = tmp_gif.name
41
+ export_to_gif(frames, gif_path)
42
 
43
  return gif_path
44
 
45
+ # Create the Gradio interface with tabs
46
+ with gr.Tabs() as demo:
47
+ with gr.TabItem("Generate from Text or Image"):
48
+ interface = gr.Interface(
49
+ fn=generate_gif,
50
+ inputs=[
51
+ gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
52
+ gr.Image(type="filepath", label="Input Image (optional)"),
53
+ gr.Textbox(lines=2, placeholder="Enter your negative prompt here...", label="Negative Prompt"),
54
+ gr.Slider(1, 100, step=1, value=50, label="Number of Inference Steps"),
55
+ gr.Slider(1, 20, step=0.1, value=9.0, label="Guidance Scale"),
56
+ gr.Number(label="Seed", value=8888)
57
+ ],
58
+ outputs=gr.Video(label="Generated GIF"),
59
+ title="I2VGen-XL GIF Generator",
60
+ description="Generate a GIF from a text prompt and/or an image using the I2VGen-XL model."
61
+ )
62
 
63
  # Launch the interface
64
+ demo.launch()