prithivMLmods committed (verified)
Commit 190a279 · 1 Parent(s): 924fbec

Update app.py

Files changed (1): app.py (+9, -302)

app.py CHANGED
@@ -2,17 +2,16 @@ import spaces
 import gradio as gr
 import torch
 from PIL import Image
-from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
+from diffusers import DiffusionPipeline, AutoencoderTiny
 import random
 import uuid
 from typing import Tuple, Union, List, Optional, Any, Dict
 import numpy as np
 import time
 import zipfile
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 # Description for the app
-DESCRIPTION = """## flux-krea vs qwen"""
+DESCRIPTION = """## Qwen Image Hpc/."""
 
 # Helper functions
 def save_image(img):
@@ -28,175 +27,11 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
-# Load pipelines
+# Load Qwen/Qwen-Image pipeline with taef1 VAE
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Flux.1-krea pipeline
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+pipe_qwen = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype, vae=taef1).to(device)
-good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype).to(device)
-pipe_krea = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=taef1).to(device)
-
-# Qwen/Qwen-Image pipeline
-pipe_qwen = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
-
-# Define custom flux_pipe_call for Flux.1-krea
-@torch.inference_mode()
-def flux_pipe_call_that_returns_an_iterable_of_images(
-    self,
-    prompt: Union[str, List[str]] = None,
-    prompt_2: Optional[Union[str, List[str]]] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    num_inference_steps: int = 28,
-    timesteps: List[int] = None,
-    guidance_scale: float = 3.5,
-    num_images_per_prompt: Optional[int] = 1,
-    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-    latents: Optional[torch.FloatTensor] = None,
-    prompt_embeds: Optional[torch.FloatTensor] = None,
-    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-    output_type: Optional[str] = "pil",
-    return_dict: bool = True,
-    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    max_sequence_length: int = 512,
-    good_vae: Optional[Any] = None,
-):
-    height = height or self.default_sample_size * self.vae_scale_factor
-    width = width or self.default_sample_size * self.vae_scale_factor
-
-    self.check_inputs(
-        prompt,
-        prompt_2,
-        height,
-        width,
-        prompt_embeds=prompt_embeds,
-        pooled_prompt_embeds=pooled_prompt_embeds,
-        max_sequence_length=max_sequence_length,
-    )
-
-    self._guidance_scale = guidance_scale
-    self._joint_attention_kwargs = joint_attention_kwargs
-    self._interrupt = False
-
-    batch_size = 1 if isinstance(prompt, str) else len(prompt)
-    device = self._execution_device
-
-    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
-    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
-        prompt=prompt,
-        prompt_2=prompt_2,
-        prompt_embeds=prompt_embeds,
-        pooled_prompt_embeds=pooled_prompt_embeds,
-        device=device,
-        num_images_per_prompt=num_images_per_prompt,
-        max_sequence_length=max_sequence_length,
-        lora_scale=lora_scale,
-    )
-
-    num_channels_latents = self.transformer.config.in_channels // 4
-    latents, latent_image_ids = self.prepare_latents(
-        batch_size * num_images_per_prompt,
-        num_channels_latents,
-        height,
-        width,
-        prompt_embeds.dtype,
-        device,
-        generator,
-        latents,
-    )
-
-    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-    image_seq_len = latents.shape[1]
-    mu = calculate_shift(
-        image_seq_len,
-        self.scheduler.config.base_image_seq_len,
-        self.scheduler.config.max_image_seq_len,
-        self.scheduler.config.base_shift,
-        self.scheduler.config.max_shift,
-    )
-    timesteps, num_inference_steps = retrieve_timesteps(
-        self.scheduler,
-        num_inference_steps,
-        device,
-        timesteps,
-        sigmas,
-        mu=mu,
-    )
-    self._num_timesteps = len(timesteps)
-
-    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
-
-    for i, t in enumerate(timesteps):
-        if self.interrupt:
-            continue
-
-        timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-        noise_pred = self.transformer(
-            hidden_states=latents,
-            timestep=timestep / 1000,
-            guidance=guidance,
-            pooled_projections=pooled_prompt_embeds,
-            encoder_hidden_states=prompt_embeds,
-            txt_ids=text_ids,
-            img_ids=latent_image_ids,
-            joint_attention_kwargs=self.joint_attention_kwargs,
-            return_dict=False,
-        )[0]
-
-        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        image = self.vae.decode(latents_for_image, return_dict=False)[0]
-        yield self.image_processor.postprocess(image, output_type=output_type)[0]
-
-        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-        torch.cuda.empty_cache()
-
-    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
-    image = good_vae.decode(latents, return_dict=False)[0]
-    self.maybe_free_model_hooks()
-    torch.cuda.empty_cache()
-    yield self.image_processor.postprocess(image, output_type=output_type)[0]
-
-pipe_krea.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe_krea)
-
-# Helper functions for Flux.1-krea
-def calculate_shift(
-    image_seq_len,
-    base_seq_len: int = 256,
-    max_seq_len: int = 4096,
-    base_shift: float = 0.5,
-    max_shift: float = 1.16,
-):
-    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    b = base_shift - m * base_seq_len
-    mu = image_seq_len * m + b
-    return mu
-
-def retrieve_timesteps(
-    scheduler,
-    num_inference_steps: Optional[int] = None,
-    device: Optional[Union[str, torch.device]] = None,
-    timesteps: Optional[List[int]] = None,
-    sigmas: Optional[List[float]] = None,
-    **kwargs,
-):
-    if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
-    if timesteps is not None:
-        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    elif sigmas is not None:
-        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    else:
-        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-    return timesteps, num_inference_steps
 
 # Aspect ratios
 aspect_ratios = {
@@ -207,55 +42,6 @@ aspect_ratios = {
     "3:4": (1140, 1472)
 }
 
-# Generation function for Flux.1-krea
-@spaces.GPU
-def generate_krea(
-    prompt: str,
-    seed: int = 0,
-    width: int = 1024,
-    height: int = 1024,
-    guidance_scale: float = 4.5,
-    randomize_seed: bool = False,
-    num_inference_steps: int = 28,
-    num_images: int = 1,
-    zip_images: bool = False,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device).manual_seed(seed)
-
-    start_time = time.time()
-
-    images = []
-    for _ in range(num_images):
-        final_img = list(pipe_krea.flux_pipe_call_that_returns_an_iterable_of_images(
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            width=width,
-            height=height,
-            generator=generator,
-            output_type="pil",
-            good_vae=good_vae,
-        ))[-1] # Take the final image only
-        images.append(final_img)
-
-    end_time = time.time()
-    duration = end_time - start_time
-
-    image_paths = [save_image(img) for img in images]
-
-    zip_path = None
-    if zip_images:
-        zip_name = str(uuid.uuid4()) + ".zip"
-        with zipfile.ZipFile(zip_name, 'w') as zipf:
-            for i, img_path in enumerate(image_paths):
-                zipf.write(img_path, arcname=f"Img_{i}.png")
-        zip_path = zip_name
-
-    return image_paths, seed, f"{duration:.2f}", zip_path
-
 # Generation function for Qwen/Qwen-Image
 @spaces.GPU
 def generate_qwen(
@@ -304,54 +90,6 @@ def generate_qwen(
 
     return image_paths, seed, f"{duration:.2f}", zip_path
 
-# Main generation function
-@spaces.GPU
-def generate(
-    model_choice: str,
-    prompt: str,
-    negative_prompt: str = "",
-    use_negative_prompt: bool = False,
-    seed: int = 0,
-    width: int = 1024,
-    height: int = 1024,
-    guidance_scale: float = 3.5,
-    randomize_seed: bool = False,
-    num_inference_steps: int = 28,
-    num_images: int = 1,
-    zip_images: bool = False,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if model_choice == "Flux.1-krea":
-        return generate_krea(
-            prompt=prompt,
-            seed=seed,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            randomize_seed=randomize_seed,
-            num_inference_steps=num_inference_steps,
-            num_images=num_images,
-            zip_images=zip_images,
-            progress=progress,
-        )
-    elif model_choice == "Qwen Image":
-        final_negative_prompt = negative_prompt if use_negative_prompt else ""
-        return generate_qwen(
-            prompt=prompt,
-            negative_prompt=final_negative_prompt,
-            seed=seed,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            randomize_seed=randomize_seed,
-            num_inference_steps=num_inference_steps,
-            num_images=num_images,
-            zip_images=zip_images,
-            progress=progress,
-        )
-    else:
-        raise ValueError("Invalid model choice")
-
 # Examples
 examples = [
     "An attractive young woman with blue eyes lying face down on the bed, light white and light amber, timeless beauty, sunrays shine upon it",
@@ -387,13 +125,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         run_button = gr.Button("Run", scale=0, variant="primary")
     result = gr.Gallery(label="Result", columns=1, show_label=False, preview=True)
 
-    with gr.Row():
-        model_choice = gr.Radio(
-            choices=["Flux.1-krea", "Qwen Image"],
-            label="Select Model",
-            value="Flux.1-krea"
-        )
-
    with gr.Accordion("Additional Options", open=False):
        aspect_ratio = gr.Dropdown(
            label="Aspect Ratio",
@@ -401,9 +132,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
            value="1:1",
        )
        use_negative_prompt = gr.Checkbox(
-            label="Use negative prompt (Qwen Image only)",
+            label="Use negative prompt",
            value=False,
-            visible=False
        )
        negative_prompt = gr.Text(
            label="Negative prompt",
@@ -439,14 +169,14 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
            minimum=0.0,
            maximum=20.0,
            step=0.1,
-            value=3.5,
+            value=4.0,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=100,
            step=1,
-            value=28,
+            value=50,
        )
        num_images = gr.Slider(
            label="Number of images",
@@ -473,27 +203,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
        outputs=[width, height]
    )
 
-    # Update model-specific settings
-    def update_settings(mc):
-        if mc == "Flux.1-krea":
-            return (
-                gr.update(value=28),
-                gr.update(value=3.5),
-                gr.update(visible=False)
-            )
-        elif mc == "Qwen Image":
-            return (
-                gr.update(value=50),
-                gr.update(value=4.0),
-                gr.update(visible=True)
-            )
-
-    model_choice.change(
-        fn=update_settings,
-        inputs=model_choice,
-        outputs=[num_inference_steps, guidance_scale, use_negative_prompt]
-    )
-
    # Negative prompt visibility
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
@@ -504,12 +213,10 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    # Run button and prompt submit
    gr.on(
        triggers=[prompt.submit, run_button.click],
-        fn=generate,
+        fn=generate_qwen,
        inputs=[
-            model_choice,
            prompt,
            negative_prompt,
-            use_negative_prompt,
            seed,
            width,
            height,
@@ -528,7 +235,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
        examples=examples,
        inputs=prompt,
        outputs=[result, seed_display, generation_time, zip_file],
-        fn=generate,
+        fn=generate_qwen,
        cache_examples=False,
    )
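For reference, the pipeline setup that survives this commit reduces to the short sketch below. The loading lines are taken from the diff; the final call is a hypothetical usage example with made-up prompt text, since the body of generate_qwen is outside the hunks shown here and its exact call into pipe_qwen is an assumption. The 1024x1024 size and 50 steps mirror the new slider defaults.

import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading path kept by this commit: Qwen/Qwen-Image with the tiny taef1 VAE swapped in
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe_qwen = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype, vae=taef1).to(device)

# Hypothetical single-image call (generate_qwen's real body is not part of this diff)
generator = torch.Generator(device).manual_seed(0)
image = pipe_qwen(
    prompt="A cat sitting on a windowsill at sunset",
    width=1024,
    height=1024,
    num_inference_steps=50,
    generator=generator,
).images[0]
image.save("qwen_image.png")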