Ryukijano committed · verified · Commit fabdc5a · Parent(s): 1b66a10

Update app.py

Files changed (1): app.py (+103, -34)

app.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 import gradio as gr
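The two backend flags above enable TF32 tensor-core math for float32 matmuls and cuDNN convolutions, trading a little precision for throughput on Ampere-class and newer GPUs. A minimal sketch of the same setup, plus the newer single-switch equivalent (the set_float32_matmul_precision call is not part of this commit):

    import torch

    # Per-backend TF32 flags, as used in app.py.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Newer single switch: "high" permits TF32 for float32 matmuls.
    torch.set_float32_matmul_precision("high")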
@@ -25,12 +26,17 @@ pipe = FluxWithCFGPipeline.from_pretrained(
 )
 pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
 pipe.to("cuda")
-pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
+pipe.load_lora_weights(
+    "hugovntr/flux-schnell-realism",
+    weight_name="schnell-realism_v2.3.safetensors",
+    adapter_name="better",
+)
 pipe.set_adapters(["better"], adapter_weights=[1.0])
 pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
 pipe.unload_lora_weights()
 
-pipe.unet.to(memory_format=torch.channels_last)
+# Corrected: Access 'transformer' instead of 'unet'
+pipe.transformer.to(memory_format=torch.channels_last)
 pipe.vae.to(memory_format=torch.channels_last)
 
 pipe.enable_xformers_memory_efficient_attention()
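This hunk loads a realism LoRA, fuses it into the base weights, and then drops the adapter state, so steady-state inference carries no LoRA indirection; it also fixes the memory-format call to target pipe.transformer, since Flux pipelines have a transformer rather than a unet. A minimal sketch of the load → set → fuse → unload pattern with diffusers (note that recent diffusers releases spell the fuse keyword adapter_names, so the diff's adapter_name=["better"] likely is not doing what it intends):

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Attach the LoRA under a named adapter and weight it.
    pipe.load_lora_weights(
        "hugovntr/flux-schnell-realism",
        weight_name="schnell-realism_v2.3.safetensors",
        adapter_name="better",
    )
    pipe.set_adapters(["better"], adapter_weights=[1.0])

    # Bake the adapter into the base weights, then discard the LoRA state.
    pipe.fuse_lora(adapter_names=["better"], lora_scale=1.0)
    pipe.unload_lora_weights()

    # channels_last mainly helps conv-heavy modules such as the VAE.
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)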
@@ -39,7 +45,15 @@ torch.cuda.empty_cache()
 
 # Inference function
 @spaces.GPU(duration=25)
-def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
+def generate_image(
+    prompt,
+    seed=24,
+    width=DEFAULT_WIDTH,
+    height=DEFAULT_HEIGHT,
+    randomize_seed=False,
+    num_inference_steps=2,
+    progress=gr.Progress(track_tqdm=True),
+):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(int(float(seed)))
@@ -47,9 +61,15 @@ def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
     start_time = time.time()
 
     # Initialize static inputs for CUDA graph
-    static_latents = torch.randn((1, 4, height // 8, width // 8), dtype=dtype, device="cuda")
-    static_prompt_embeds = torch.randn((2, 77, 768), dtype=dtype, device="cuda")  # Adjust dimensions as needed
-    static_pooled_prompt_embeds = torch.randn((2, 768), dtype=dtype, device="cuda")  # Adjust dimensions as needed
+    static_latents = torch.randn(
+        (1, 4, height // 8, width // 8), dtype=dtype, device="cuda"
+    )
+    static_prompt_embeds = torch.randn(
+        (2, 77, 768), dtype=dtype, device="cuda"
+    )  # Adjust dimensions as needed
+    static_pooled_prompt_embeds = torch.randn(
+        (2, 768), dtype=dtype, device="cuda"
+    )  # Adjust dimensions as needed
     static_text_ids = torch.tensor([[[1, 2, 3]]], dtype=torch.int32, device="cuda")
     static_latent_image_ids = torch.tensor([1], dtype=torch.int64, device="cuda")
     static_timestep = torch.tensor([999], dtype=dtype, device="cuda")
@@ -86,11 +106,22 @@ def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
             joint_attention_kwargs=pipe.joint_attention_kwargs,
             return_dict=False,
         )[0]
-        static_latents_out = pipe.scheduler.step(static_noise_pred, static_timestep, static_latents, return_dict=False)[0]
-        static_output = pipe._decode_latents_to_image(static_latents_out, height, width, "pil")
+        static_latents_out = pipe.scheduler.step(
+            static_noise_pred, static_timestep, static_latents, return_dict=False
+        )[0]
+        static_output = pipe._decode_latents_to_image(
+            static_latents_out, height, width, "pil"
+        )
 
     # Graph-based generation function
-    def generate_with_graph(latents, prompt_embeds, pooled_prompt_embeds, text_ids, latent_image_ids, timestep):
+    def generate_with_graph(
+        latents,
+        prompt_embeds,
+        pooled_prompt_embeds,
+        text_ids,
+        latent_image_ids,
+        timestep,
+    ):
         static_latents.copy_(latents)
         static_prompt_embeds.copy_(prompt_embeds)
         static_pooled_prompt_embeds.copy_(pooled_prompt_embeds)
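The static_* tensors and generate_with_graph above implement the standard CUDA-graph replay pattern: capture one denoising step into a graph once, then re-run it by copying fresh inputs into the captured tensors' memory. Shapes and tensor addresses must stay fixed between capture and replay, which is why the app pre-allocates static buffers and copy_s into them instead of reassigning. A minimal, self-contained sketch of the pattern (toy model and hypothetical shapes, not the app's exact code):

    import torch

    model = torch.nn.Linear(64, 64).cuda()
    static_in = torch.randn(8, 64, device="cuda")
    static_out = torch.empty(8, 64, device="cuda")

    # Warm up on a side stream so capture sees initialized kernels.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            static_out.copy_(model(static_in))
    torch.cuda.current_stream().wait_stream(s)

    # Capture one step; the graph records the kernels and the fixed
    # buffers they touch.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(model(static_in))

    def run_with_graph(x):
        static_in.copy_(x)   # move fresh inputs into the captured buffer
        graph.replay()       # re-launch the recorded kernels
        return static_out.clone()

    out = run_with_graph(torch.randn(8, 64, device="cuda"))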
@@ -101,15 +132,15 @@ def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
         return static_output
 
     # Only generate the last image in the sequence
-    img = pipe.generate_images(
-        prompt=prompt,
-        width=width,
-        height=height,
-        num_inference_steps=num_inference_steps,
-        generator=generator,
-        generate_with_graph=generate_with_graph
-    )
-    latency = f"Latency: {(time.time()-start_time):.2f} seconds"
+    img = pipe.generate_images(
+        prompt=prompt,
+        width=width,
+        height=height,
+        num_inference_steps=num_inference_steps,
+        generator=generator,
+        generate_with_graph=generate_with_graph,
+    )
+    latency = f"Latency: {(time.time()-start_time):.2f} seconds"
     return img, seed, latency
 
 # Example prompts
@@ -127,12 +158,18 @@ examples = [
 with gr.Blocks() as demo:
     with gr.Column(elem_id="app-container"):
         gr.Markdown("# 🎨 Realtime FLUX Image Generator")
-        gr.Markdown("Generate stunning images in real-time with Modified Flux.Schnell pipeline.")
-        gr.Markdown("<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>")
+        gr.Markdown(
+            "Generate stunning images in real time with a modified Flux.Schnell pipeline."
+        )
+        gr.Markdown(
+            "<span style='color: red;'>Note: Generation occasionally stalls or stops (cause unknown). If that happens, refresh the page.</span>"
+        )
 
         with gr.Row():
             with gr.Column(scale=2.5):
-                result = gr.Image(label="Generated Image", show_label=False, interactive=False)
+                result = gr.Image(
+                    label="Generated Image", show_label=False, interactive=False
+                )
             with gr.Column(scale=1):
                 prompt = gr.Text(
                     label="Prompt",
@@ -146,15 +183,39 @@ with gr.Blocks() as demo:
 
             with gr.Column("Advanced Options"):
                 with gr.Row():
-                    realtime = gr.Checkbox(label="Realtime Toggler", info="If TRUE then uses more GPU but create image in realtime.", value=False)
+                    realtime = gr.Checkbox(
+                        label="Realtime Toggler",
+                        info="If enabled, uses more GPU but generates images in real time.",
+                        value=False,
+                    )
                     latency = gr.Text(label="Latency")
                 with gr.Row():
                     seed = gr.Number(label="Seed", value=42)
-                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                    randomize_seed = gr.Checkbox(
+                        label="Randomize Seed", value=True
+                    )
                 with gr.Row():
-                    width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
-                    height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
-                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)
+                    width = gr.Slider(
+                        label="Width",
+                        minimum=256,
+                        maximum=MAX_IMAGE_SIZE,
+                        step=32,
+                        value=DEFAULT_WIDTH,
+                    )
+                    height = gr.Slider(
+                        label="Height",
+                        minimum=256,
+                        maximum=MAX_IMAGE_SIZE,
+                        step=32,
+                        value=DEFAULT_HEIGHT,
+                    )
+                    num_inference_steps = gr.Slider(
+                        label="Inference Steps",
+                        minimum=1,
+                        maximum=4,
+                        step=1,
+                        value=DEFAULT_INFERENCE_STEPS,
+                    )
 
             with gr.Row():
                 gr.Markdown("### 🌟 Inspiration Gallery")
@@ -164,7 +225,7 @@ with gr.Blocks() as demo:
         fn=generate_image,
         inputs=[prompt],
         outputs=[result, seed, latency],
-        cache_examples="lazy"
+        cache_examples="lazy",
     )
 
     enhanceBtn.click(
@@ -173,7 +234,7 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None
+        concurrency_limit=None,
     )
 
     generateBtn.click(
@@ -182,13 +243,13 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         api_name="RealtimeFlux",
-        queue=False
+        queue=False,
     )
 
     def update_ui(realtime_enabled):
         return {
             prompt: gr.update(interactive=True),
-            generateBtn: gr.update(visible=not realtime_enabled)
+            generateBtn: gr.update(visible=not realtime_enabled),
         }
 
     realtime.change(
@@ -196,7 +257,7 @@ with gr.Blocks() as demo:
         inputs=[realtime],
         outputs=[prompt, generateBtn],
         queue=False,
-        concurrency_limit=None
+        concurrency_limit=None,
    )
 
     async def realtime_generation(*args):
@@ -211,18 +272,26 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None
+        concurrency_limit=None,
     )
 
     for component in [prompt, width, height, num_inference_steps]:
         component.input(
             fn=realtime_generation,
-            inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
+            inputs=[
+                realtime,
+                prompt,
+                seed,
+                width,
+                height,
+                randomize_seed,
+                num_inference_steps,
+            ],
             outputs=[result, seed, latency],
             show_progress="hidden",
             trigger_mode="always_last",
             queue=True,
-            concurrency_limit=None
+            concurrency_limit=None,
         )
 
 # Launch the app
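The realtime path in the final hunk wires .input events on the prompt and sliders with trigger_mode="always_last", so a burst of keystrokes collapses to one generation using the newest values. A minimal sketch of that Gradio pattern (render is a hypothetical stand-in for the app's generator):

    import gradio as gr

    def render(prompt):
        # Hypothetical stand-in for the real image generator.
        return f"would generate: {prompt}"

    with gr.Blocks() as demo:
        prompt = gr.Text(label="Prompt")
        out = gr.Text(label="Result")
        # Fire on every edit, but drop stale events while one is in
        # flight and keep only the most recent ("always_last").
        prompt.input(
            fn=render,
            inputs=[prompt],
            outputs=[out],
            show_progress="hidden",
            trigger_mode="always_last",
            queue=True,
        )

    demo.launch()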
 