Ryukijano committed (verified)
Commit 858fb7b · Parent: 105e0dd

Update app.py

Files changed (1)
  1. app.py +110 -112
app.py CHANGED
@@ -1,16 +1,15 @@
-import torch
-
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
 import gradio as gr
 import numpy as np
 import random
 import spaces
+import torch
 import time
 from diffusers import DiffusionPipeline, AutoencoderTiny
 from diffusers.models.attention_processor import AttnProcessor2_0
 from custom_pipeline import FluxWithCFGPipeline

+torch.backends.cuda.matmul.allow_tf32 = True
+
 # Constants
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
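Note: the TF32 switch that moves below the imports trades a little float32 matmul precision for tensor-core throughput on Ampere-class and newer GPUs. A minimal standalone sketch; the cuDNN counterpart that this commit drops is shown for comparison:

    import torch

    # Allow TensorFloat-32 tensor cores for float32 matmuls (Ampere and newer).
    torch.backends.cuda.matmul.allow_tf32 = True
    # The previous revision also enabled the cuDNN counterpart:
    torch.backends.cudnn.allow_tf32 = True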
@@ -25,72 +24,110 @@ pipe = FluxWithCFGPipeline.from_pretrained(
 )
 pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
 pipe.to("cuda")
-pipe.load_lora_weights(
-    "hugovntr/flux-schnell-realism",
-    weight_name="schnell-realism_v2.3.safetensors",
-    adapter_name="better",
-)
+pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
 pipe.set_adapters(["better"], adapter_weights=[1.0])
 pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
 pipe.unload_lora_weights()

-# Correctly set memory format
-pipe.transformer.to(memory_format=torch.channels_last)
-pipe.vae.to(memory_format=torch.channels_last)
-
-# Conditionally enable xformers only for the transformer
-if hasattr(pipe, "transformer") and torch.cuda.is_available():
-    try:
-        pipe.transformer.enable_xformers_memory_efficient_attention()
-    except Exception as e:
-        print(
-            "Warning: Could not enable xformers for the transformer due to the following error:"
-        )
-        print(e)
+# Memory optimizations
+pipe.unet.to(memory_format=torch.channels_last)  # Channels last
+pipe.enable_xformers_memory_efficient_attention()  # Flash Attention
+
+# CUDA Graph setup
+static_inputs = None
+static_model = None
+graph = None
+
+def setup_cuda_graph(prompt, height, width, num_inference_steps):
+    global static_inputs, static_model, graph
+
+    batch_size = 1 if isinstance(prompt, str) else len(prompt)
+    device = "cuda"
+    num_images_per_prompt = 1
+
+    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
+        prompt=prompt,
+        prompt_2=None,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        max_sequence_length=300,
+        lora_scale=None,
+    )
+
+    latents, latent_image_ids = pipe.prepare_latents(
+        batch_size * num_images_per_prompt,
+        pipe.transformer.config.in_channels // 4,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        None,
+        None,
+    )
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+    image_seq_len = latents.shape[1]
+    mu = calculate_timestep_shift(image_seq_len)
+
+    timesteps, num_inference_steps = prepare_timesteps(
+        pipe.scheduler,
+        num_inference_steps,
+        device,
+        None,
+        sigmas,
+        mu=mu,
+    )
+
+    guidance = torch.full([1], 3.5, device=device, dtype=torch.float16).expand(latents.shape[0]) if pipe.transformer.config.guidance_embeds else None
+
+    static_inputs = {
+        "hidden_states": latents,
+        "timestep": timesteps,
+        "guidance": guidance,
+        "pooled_projections": pooled_prompt_embeds,
+        "encoder_hidden_states": prompt_embeds,
+        "txt_ids": text_ids,
+        "img_ids": latent_image_ids,
+        "joint_attention_kwargs": None,
+    }

-torch.cuda.empty_cache()
+    static_model = torch.cuda.make_graphed_callables(pipe.transformer, (static_inputs,))
+    graph = torch.cuda.CUDAGraph()
+
+    with torch.cuda.graph(graph):
+        static_output = static_model(**static_inputs)

 # Inference function
 @spaces.GPU(duration=25)
-def generate_image(
-    prompt,
-    seed=24,
-    width=DEFAULT_WIDTH,
-    height=DEFAULT_HEIGHT,
-    randomize_seed=False,
-    num_inference_steps=2,
-    progress=gr.Progress(track_tqdm=True),
-):
+def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
+    global static_inputs, graph
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(int(float(seed)))

     start_time = time.time()
+
+    if static_inputs is None:
+        setup_cuda_graph(prompt, height, width, num_inference_steps)

-    # Dynamically determine shapes based on input width/height
-    latents_shape = (1, 4, height // 8, width // 8)
-    prompt_embeds_shape = (
-        1,
-        pipe.transformer.text_encoder.config.max_position_embeddings,
-        pipe.transformer.text_encoder.config.hidden_size,
-    )
-    pooled_prompt_embeds_shape = (
+    static_inputs["hidden_states"].copy_(pipe.prepare_latents(
         1,
-        pipe.transformer.text_encoder.config.hidden_size,
-    )
-
-    # Only generate the last image in the sequence
-    img = pipe.generate_images(
-        prompt=prompt,
-        width=width,
-        height=height,
-        num_inference_steps=num_inference_steps,
-        generator=generator,
-        latents_shape=latents_shape,
-        prompt_embeds_shape=prompt_embeds_shape,
-        pooled_prompt_embeds_shape=pooled_prompt_embeds_shape
-    )
-    latency = f"Latency: {(time.time()-start_time):.2f} seconds"
+        pipe.transformer.config.in_channels // 4,
+        height,
+        width,
+        static_inputs["encoder_hidden_states"].dtype,
+        "cuda",
+        generator,
+        None,
+    )[0])
+
+    graph.replay()
+    latents = static_inputs["hidden_states"]
+
+    img = pipe._decode_latents_to_image(latents, height, width, "pil")
+    latency = f"Latency: {(time.time()-start_time):.2f} seconds"
     return img, seed, latency

 # Example prompts
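For context, the new setup_cuda_graph/generate_image pair follows PyTorch's CUDA Graphs capture-and-replay pattern: allocate static input tensors once, capture one forward pass into a graph, then serve each request by copying fresh data into the static tensors and calling replay(). A minimal self-contained sketch of that pattern, with a placeholder nn.Linear standing in for the FLUX transformer:

    import torch

    net = torch.nn.Linear(64, 64).cuda()            # placeholder model
    static_in = torch.randn(8, 64, device="cuda")   # static input buffer

    # Warm up on a side stream before capture, as the CUDA Graphs docs advise.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            net(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass; static_out stays valid across replays.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = net(static_in)

    # Per request: refill the static buffer, replay, read the static output.
    static_in.copy_(torch.randn(8, 64, device="cuda"))
    graph.replay()
    result = static_out.clone()

Note that torch.cuda.make_graphed_callables already returns graph-backed callables, so the commit's additional manual capture of static_model is redundant in principle; either mechanism alone implements the pattern.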
@@ -108,18 +145,12 @@ examples = [
 with gr.Blocks() as demo:
     with gr.Column(elem_id="app-container"):
         gr.Markdown("# 🎨 Realtime FLUX Image Generator")
-        gr.Markdown(
-            "Generate stunning images in real-time with Modified Flux.Schnell pipeline."
-        )
-        gr.Markdown(
-            "<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>"
-        )
+        gr.Markdown("Generate stunning images in real-time with Modified Flux.Schnell pipeline.")
+        gr.Markdown("<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>")

         with gr.Row():
             with gr.Column(scale=2.5):
-                result = gr.Image(
-                    label="Generated Image", show_label=False, interactive=False
-                )
+                result = gr.Image(label="Generated Image", show_label=False, interactive=False)
             with gr.Column(scale=1):
                 prompt = gr.Text(
                     label="Prompt",
@@ -133,39 +164,15 @@ with gr.Blocks() as demo:

                 with gr.Column("Advanced Options"):
                     with gr.Row():
-                        realtime = gr.Checkbox(
-                            label="Realtime Toggler",
-                            info="If TRUE then uses more GPU but create image in realtime.",
-                            value=False,
-                        )
+                        realtime = gr.Checkbox(label="Realtime Toggler", info="If TRUE then uses more GPU but create image in realtime.", value=False)
                         latency = gr.Text(label="Latency")
                     with gr.Row():
                         seed = gr.Number(label="Seed", value=42)
-                        randomize_seed = gr.Checkbox(
-                            label="Randomize Seed", value=True
-                        )
+                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                     with gr.Row():
-                        width = gr.Slider(
-                            label="Width",
-                            minimum=256,
-                            maximum=MAX_IMAGE_SIZE,
-                            step=32,
-                            value=DEFAULT_WIDTH,
-                        )
-                        height = gr.Slider(
-                            label="Height",
-                            minimum=256,
-                            maximum=MAX_IMAGE_SIZE,
-                            step=32,
-                            value=DEFAULT_HEIGHT,
-                        )
-                        num_inference_steps = gr.Slider(
-                            label="Inference Steps",
-                            minimum=1,
-                            maximum=4,
-                            step=1,
-                            value=DEFAULT_INFERENCE_STEPS,
-                        )
+                        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
+                        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
+                        num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)

         with gr.Row():
             gr.Markdown("### 🌟 Inspiration Gallery")
@@ -175,7 +182,7 @@ with gr.Blocks() as demo:
         fn=generate_image,
         inputs=[prompt],
         outputs=[result, seed, latency],
-        cache_examples="lazy",
+        cache_examples="lazy"
     )

     enhanceBtn.click(
@@ -184,7 +191,7 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None,
+        concurrency_limit=None
     )

     generateBtn.click(
@@ -199,7 +206,7 @@ with gr.Blocks() as demo:
     def update_ui(realtime_enabled):
         return {
             prompt: gr.update(interactive=True),
-            generateBtn: gr.update(visible=not realtime_enabled),
+            generateBtn: gr.update(visible=not realtime_enabled)
         }

     realtime.change(
@@ -207,13 +214,12 @@ with gr.Blocks() as demo:
         inputs=[realtime],
         outputs=[prompt, generateBtn],
         queue=False,
-        concurrency_limit=None,
+        concurrency_limit=None
     )

     def realtime_generation(*args):
         if args[0]:  # If realtime is enabled
-            img, seed, latency = generate_image(*args[1:])
-            return img, seed, latency
+            return next(generate_image(*args[1:]))

     prompt.submit(
         fn=generate_image,
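The switch to next(generate_image(*args[1:])) assumes generate_image yields results as a generator; as committed above it returns a plain tuple, on which next() would raise TypeError. A hypothetical sketch of the generator-style function this call expects (the name and stages are illustrative only, not the Space's actual code):

    def generate_image_streaming(prompt, seed=24):
        # Yield intermediate frames so callers can iterate or take the first.
        for stage in ("draft", "refined", "final"):   # placeholder stages
            yield f"{stage} image for {prompt!r}", seed, "Latency: ..."

    img, seed, latency = next(generate_image_streaming("a sunset"))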
@@ -221,27 +227,19 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None,
+        concurrency_limit=None
     )

     for component in [prompt, width, height, num_inference_steps]:
         component.input(
             fn=realtime_generation,
-            inputs=[
-                realtime,
-                prompt,
-                seed,
-                width,
-                height,
-                randomize_seed,
-                num_inference_steps,
-            ],
+            inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
             trigger_mode="always_last",
-            queue=True,
-            concurrency_limit=None,
+            queue=False,
+            concurrency_limit=None
         )

 # Launch the app
-demo.launch()
+demo.queue(max_size=5, concurrency_count=1).launch()
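One compatibility note on the new launch line: concurrency_count is the Gradio 3.x spelling, and Gradio 4.x removed it in favor of default_concurrency_limit. A sketch of both, assuming the Space pins Gradio 3.x as committed:

    # Gradio 3.x (as committed): one worker, queue capped at 5 waiting events.
    demo.queue(max_size=5, concurrency_count=1).launch()

    # Gradio 4.x equivalent after the rename:
    demo.queue(max_size=5, default_concurrency_limit=1).launch()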