multimodalart's picture
Update app.py
bcdfcef verified
raw
history blame
3.76 kB
import gradio as gr
import numpy as np
import torch
import spaces
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.utils import export_to_gif
from huggingface_hub import hf_hub_download
from PIL import Image
import uuid
import random
# Choose the compute device and its matching precision in one decision:
# bfloat16 when CUDA is available, float32 when falling back to the CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.bfloat16
else:
    device, torch_dtype = "cpu", torch.float32
def split_image(input_image, num_splits=4):
    """Cut a horizontal strip image into ``num_splits`` equal-width frames.

    The frame width is derived from the image itself (``width // num_splits``)
    and every frame spans the full image height, so the split stays correct
    if the generation size changes. For the 1280x320 strips generated by
    ``infer`` this yields four 320x320 frames.

    Args:
        input_image: Image-like object exposing ``width``, ``height`` and
            ``crop(box)`` (e.g. a ``PIL.Image.Image``).
        num_splits: Number of frames to cut; values below 1 yield an empty list.

    Returns:
        List of cropped frames, ordered left to right.
    """
    if num_splits < 1:
        return []
    frame_width = input_image.width // num_splits
    frame_height = input_image.height
    # Crop boxes are (left, upper, right, lower) in pixel coordinates.
    return [
        input_image.crop((i * frame_width, 0, (i + 1) * frame_width, frame_height))
        for i in range(num_splits)
    ]
# Load FLUX.1-dev once at import time and move it to the selected device.
# Without `.to(device)` the pipeline would stay on the CPU even when CUDA is
# available; on ZeroGPU Spaces this move must happen at module level,
# outside the @spaces.GPU-decorated function.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch_dtype
).to(device)
# Largest seed offered in the UI / drawn when randomizing (int32 max).
MAX_SEED = np.iinfo(np.int32).max
@spaces.GPU
def infer(prompt, seed=1, randomize_seed=False, num_inference_steps=28):
    """Generate a looping multi-frame GIF for *prompt* with FLUX.

    The model is asked for one wide strip of side-by-side stills. The strip
    is then cut into frames with ``split_image`` and written to a uniquely
    named GIF file in the current working directory.

    Args:
        prompt: Free-text description of the scene to animate.
        seed: Seed for the torch generator; replaced when *randomize_seed*.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        Tuple ``(gif_path, strip_image, seed)``, where *seed* is the int seed
        actually used, so the UI can display it for reproduction.
    """
    # Frame count and per-frame size are defined once so the prompt text, the
    # generated strip width and the split stay consistent.
    num_frames = 4
    frame_size = 320
    prompt_template = (
        f"A side by side {num_frames} frame image showing high quality consecutive "
        "stills from a looped gif animation moving from left to right. "
        "The scene has motion. The stills are of "
        f"{prompt}"
    )
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders pass their value as a float; torch.Generator.manual_seed
    # and the step count need ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt_template,
        num_inference_steps=int(num_inference_steps),
        num_images_per_prompt=1,
        generator=generator,
        height=frame_size,
        width=frame_size * num_frames,
    ).images[0]
    gif_name = f"{uuid.uuid4().hex}-flux.gif"
    export_to_gif(split_image(image, num_frames), gif_name, fps=4)
    return gif_name, image, seed
# Example prompts offered below the prompt box.
examples = [
    'a cute cat raising a sign that reads "Flux does Video?"',
    "Chris Rock eating pizza",
    "A flying saucer over the white house",
]

# Page CSS: centre the main column at 640px wide and cap the height of the
# element with id "strip".
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
    "#strip{\n"
    "max-height: 160px\n"
    "}\n"
)
# Gradio UI: prompt box + Run button, the generated GIF, the raw strip it was
# cut from, advanced generation settings and cached example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "# FLUX Gif Animations\n"
            "Generate gifs with FLUX [dev]. "
            "Concept idea by [fofr](https://x.com/fofrAI). "
            "Diffusers implementation by [Dhruv](https://x.com/_DhruvNair_)"
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        # The animated GIF, then the side-by-side strip it was split from.
        result = gr.Image(label="Result", show_label=False)
        result_full = gr.Image(label="Gif Strip", elem_id="strip")
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=32,
                step=1,
                value=28,
            )
        # Examples run `infer` with only the prompt (its defaults supply the
        # rest); results are computed on first use and then cached.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            outputs=[result, result_full, seed],
            fn=infer,
            cache_examples="lazy",
        )
    # Clicking Run or pressing Enter in the prompt box both generate; the seed
    # slider is updated with the seed `infer` actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, num_inference_steps],
        outputs=[result, result_full, seed],
    )

demo.queue().launch()