multimodalart HF Staff commited on
Commit
2af7c18
·
verified ·
1 Parent(s): 214ca22

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from diffusers import FluxPipeline, FluxTransformer2DModel
5
+ from PIL import Image
6
+ from diffusers.utils import export_to_gif
7
+ import uuid
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ if torch.cuda.is_available():
12
+ torch_dtype = torch.bfloat16
13
+ else:
14
+ torch_dtype = torch.float32
15
+
16
+ def split_image(input_image, num_splits=4):
17
+ # Create a list to store the output images
18
+ output_images = []
19
+
20
+ # Split the image into four 256x256 sections
21
+ for i in range(num_splits):
22
+ left = i * 256
23
+ right = (i + 1) * 256
24
+ box = (left, 0, right, 256)
25
+ output_images.append(input_image.crop(box))
26
+
27
+ return output_images
28
+
29
+ pipe = FluxPipeline.from_pretrained(
30
+ "black-forest-labs/FLUX.1-schnell",
31
+ torch_dtype=torch_dtype
32
+ )
33
+ pipe.to(device)
34
+
35
+ @spaces.GPU
36
+ def infer(prompt, seed, randomize_seed, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
37
+ prompt_template = f"A side by side 4 frame image showing consecutive stills from a looped gif moving from left to right. The gif is {prompt}"
38
+ if randomize_seed:
39
+ seed = random.randint(0, MAX_SEED)
40
+
41
+ generator = torch.Generator().manual_seed(seed)
42
+
43
+ image = pipe(
44
+ prompt=prompt,
45
+ num_inference_steps=num_inference_steps,
46
+ num_images_per_prompt=1,
47
+ generator=torch.Generator(device).manual_seed(seed),
48
+ height=height,
49
+ width=width
50
+ ).images[0]
51
+ gif_name = f"{uuid.uuid4().hex}-flux.gif"
52
+ export_to_gif(split_image(image, 4), gif_name, fps=4)
53
+ return gif_name, seed
54
+
55
+ examples = [
56
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
57
+ "An astronaut riding a green horse",
58
+ "A delicious ceviche cheesecake slice",
59
+ ]
60
+
61
+ css="""
62
+ #col-container {
63
+ margin: 0 auto;
64
+ max-width: 640px;
65
+ }
66
+ """
67
+
68
+ with gr.Blocks(css=css) as demo:
69
+
70
+ with gr.Column(elem_id="col-container"):
71
+ gr.Markdown(f"""
72
+ # FLUX.1 Schnell Animations
73
+ Generate gifs with
74
+ """)
75
+
76
+ with gr.Row():
77
+
78
+ prompt = gr.Text(
79
+ label="Prompt",
80
+ show_label=False,
81
+ max_lines=1,
82
+ placeholder="Enter your prompt",
83
+ container=False,
84
+ )
85
+
86
+ run_button = gr.Button("Run", scale=0)
87
+
88
+ result = gr.Image(label="Result", show_label=False)
89
+
90
+ with gr.Accordion("Advanced Settings", open=False):
91
+
92
+ seed = gr.Slider(
93
+ label="Seed",
94
+ minimum=0,
95
+ maximum=MAX_SEED,
96
+ step=1,
97
+ value=0,
98
+ )
99
+
100
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
+
102
+ num_inference_steps = gr.Slider(
103
+ label="Number of inference steps",
104
+ minimum=1,
105
+ maximum=12,
106
+ step=1,
107
+ value=4,
108
+ )
109
+
110
+
111
+ gr.Examples(
112
+ examples = examples,
113
+ inputs = [prompt]
114
+ )
115
+ gr.on(
116
+ trigger=[run_button.click, prompt.submit],
117
+ fn = infer,
118
+ inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
119
+ outputs = [result, seed]
120
+ )
121
+
122
+ demo.queue().launch()