prithivMLmods committed on
Commit
ea8b158
·
verified ·
1 Parent(s): e03a12c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -35
app.py CHANGED
@@ -2,17 +2,20 @@ import spaces
2
  import gradio as gr
3
  import torch
4
  from PIL import Image
5
- from diffusers import DiffusionPipeline
6
  import random
7
  import uuid
8
- from typing import Tuple
9
  import numpy as np
10
  import time
11
  import zipfile
 
12
 
13
- DESCRIPTION = """## flux realism hpc/.
14
- """
 
15
 
 
16
  def save_image(img):
17
  unique_name = str(uuid.uuid4()) + ".png"
18
  img.save(unique_name)
@@ -24,37 +27,188 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
24
  return seed
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
 
27
 
28
- base_model = "black-forest-labs/FLUX.1-dev"
29
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
30
-
 
31
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
32
  trigger_word = "Super Realism"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- pipe.load_lora_weights(lora_repo)
35
- pipe.to("cuda")
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  style_list = [
38
- {
39
- "name": "3840 x 2160",
40
- "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
41
- "negative_prompt": "",
42
- },
43
- {
44
- "name": "2560 x 1440",
45
- "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
46
- "negative_prompt": "",
47
- },
48
- {
49
- "name": "HD+",
50
- "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
51
- "negative_prompt": "",
52
- },
53
- {
54
- "name": "Style Zero",
55
- "prompt": "{prompt}",
56
- "negative_prompt": "",
57
- },
58
  ]
59
 
60
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
@@ -65,8 +219,9 @@ def apply_style(style_name: str, positive: str) -> Tuple[str, str]:
65
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
66
  return p.replace("{prompt}", positive), n
67
 
 
68
  @spaces.GPU
69
- def generate(
70
  prompt: str,
71
  negative_prompt: str = "",
72
  use_negative_prompt: bool = False,
@@ -98,7 +253,7 @@ def generate(
98
 
99
  start_time = time.time()
100
 
101
- images = pipe(
102
  prompt=positive_prompt,
103
  negative_prompt=final_negative_prompt if final_negative_prompt else None,
104
  width=width,
@@ -125,11 +280,111 @@ def generate(
125
 
126
  return image_paths, seed, f"{duration:.2f}", zip_path
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  examples = [
129
  "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
130
  "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
131
  "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
132
- "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
133
  ]
134
 
135
  css = '''
@@ -145,6 +400,7 @@ footer {
145
  }
146
  '''
147
 
 
148
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
149
  gr.Markdown(DESCRIPTION)
150
  with gr.Row():
@@ -157,15 +413,22 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
157
  )
158
  run_button = gr.Button("Run", scale=0, variant="primary")
159
  result = gr.Gallery(label="Result", columns=1, show_label=False, preview=True)
160
-
 
 
 
 
 
 
 
161
  with gr.Accordion("Additional Options", open=False):
162
  style_selection = gr.Dropdown(
163
- label="Quality Style",
164
  choices=STYLE_NAMES,
165
  value=DEFAULT_STYLE_NAME,
166
  interactive=True,
167
  )
168
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
169
  negative_prompt = gr.Text(
170
  label="Negative prompt",
171
  max_lines=1,
@@ -245,6 +508,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
245
  ],
246
  fn=generate,
247
  inputs=[
 
248
  prompt,
249
  negative_prompt,
250
  use_negative_prompt,
 
2
  import gradio as gr
3
  import torch
4
  from PIL import Image
5
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
6
  import random
7
  import uuid
8
+ from typing import Tuple, Union, List, Optional, Any, Dict
9
  import numpy as np
10
  import time
11
  import zipfile
12
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
13
 
14
+ # Description for the app
15
+ DESCRIPTION = """## Flux Realism HPC with Krea Integration
16
+ Choose between 'flux.1-dev-realism' for hyper-realistic images or 'flux.1-krea' for creative outputs."""
17
 
18
+ # Helper functions
19
  def save_image(img):
20
  unique_name = str(uuid.uuid4()) + ".png"
21
  img.save(unique_name)
 
27
  return seed
28
 
29
  MAX_SEED = np.iinfo(np.int32).max
30
+ MAX_IMAGE_SIZE = 2048
31
 
32
+ # Load pipelines for both models
33
+ # Flux.1-dev-realism
34
+ base_model_dev = "black-forest-labs/FLUX.1-dev"
35
+ pipe_dev = DiffusionPipeline.from_pretrained(base_model_dev, torch_dtype=torch.bfloat16)
36
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
37
  trigger_word = "Super Realism"
38
+ pipe_dev.load_lora_weights(lora_repo)
39
+ pipe_dev.to("cuda")
40
+
41
+ # Flux.1-krea
42
+ dtype = torch.bfloat16
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
45
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype).to(device)
46
+ pipe_krea = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=taef1).to(device)
47
+
48
+ # Define the flux_pipe_call_that_returns_an_iterable_of_images for flux.1-krea
49
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    max_sequence_length: int = 512,
    good_vae: Optional[Any] = None,
):
    """Run the denoising loop as a generator, yielding an image per step.

    Intended to be bound to a pipeline instance as a method. For every
    scheduler timestep it yields a preview decoded from the current latents
    with ``self.vae``; after the loop it yields one last image decoded with
    ``good_vae``.

    Args:
        prompt: Prompt text (or list of prompts). May be ``None`` when
            ``prompt_embeds`` is supplied instead.
        prompt_2: Secondary prompt forwarded to ``self.encode_prompt``.
        height: Output height; defaults to
            ``self.default_sample_size * self.vae_scale_factor``.
        width: Output width; same default as ``height``.
        num_inference_steps: Number of denoising steps. Ignored for the
            schedule when ``timesteps`` is given.
        timesteps: Optional explicit timestep schedule. When given, no sigma
            schedule is generated (the two are mutually exclusive in
            ``retrieve_timesteps``).
        guidance_scale: Value fed to the transformer's guidance input when
            ``self.transformer.config.guidance_embeds`` is truthy.
        num_images_per_prompt: Multiplier applied to the batch size.
        generator: RNG(s) forwarded to ``self.prepare_latents``.
        latents: Optional pre-made latents forwarded to
            ``self.prepare_latents``.
        prompt_embeds: Optional precomputed prompt embeddings.
        pooled_prompt_embeds: Optional precomputed pooled embeddings.
        output_type: Forwarded to ``self.image_processor.postprocess``.
        return_dict: Accepted for signature compatibility; not used here.
        joint_attention_kwargs: Stored on the pipeline and passed to the
            transformer; its ``"scale"`` entry, if any, is used as the LoRA
            scale for prompt encoding.
        max_sequence_length: Forwarded to input checking / prompt encoding.
        good_vae: Decoder used for the final image. Falls back to
            ``self.vae`` when ``None``.

    Yields:
        The first element of ``self.image_processor.postprocess(...)`` for
        each step preview, followed by the final decoded image.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # Derive the batch size from whichever prompt representation was given.
    # NOTE(review): assumes check_inputs rejects the case where both `prompt`
    # and `prompt_embeds` are missing -- confirm against the pipeline class.
    if isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device

    lora_scale = (
        joint_attention_kwargs.get("scale", None)
        if joint_attention_kwargs is not None
        else None
    )
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # Only build the default sigma schedule when the caller did not supply
    # explicit timesteps; retrieve_timesteps refuses to take both.
    sigmas = (
        None
        if timesteps is not None
        else np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    )
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    if self.transformer.config.guidance_embeds:
        guidance = torch.full(
            [1], guidance_scale, device=device, dtype=torch.float32
        ).expand(latents.shape[0])
    else:
        guidance = None

    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # Preview: decode the latents as they are *before* this step's
        # scheduler update, using the pipeline's own VAE.
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image: use the dedicated decoder when provided, otherwise fall
    # back to the pipeline's VAE instead of failing on `None.config`.
    final_vae = good_vae if good_vae is not None else self.vae
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / final_vae.config.scaling_factor) + final_vae.config.shift_factor
    image = final_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
167
 
168
+ pipe_krea.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe_krea)
169
+
170
+ # Helper functions for flux.1-krea
171
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the shift value ``mu`` for a given sequence length.

    The result lies on the straight line through
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``; values
    outside that interval are extrapolated, not clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift at ``base_seq_len``.
        max_shift: Shift at ``max_seq_len``.

    Returns:
        The interpolated shift ``mu``.

    Raises:
        ZeroDivisionError: If ``max_seq_len == base_seq_len``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
182
+
183
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    Exactly one of three modes is used, in this order of precedence:
    explicit ``timesteps``, explicit ``sigmas``, or a plain
    ``num_inference_steps`` count. Extra keyword arguments are forwarded to
    ``scheduler.set_timesteps`` unchanged.

    Args:
        scheduler: Object exposing ``set_timesteps(...)`` and a
            ``timesteps`` attribute that is populated by that call.
        num_inference_steps: Step count, used only when neither
            ``timesteps`` nor ``sigmas`` is given.
        device: Forwarded to ``scheduler.set_timesteps``.
        timesteps: Custom timestep schedule.
        sigmas: Custom sigma schedule.
        **kwargs: Additional arguments for ``scheduler.set_timesteps``.

    Returns:
        A ``(timesteps, num_inference_steps)`` tuple. When a custom schedule
        was supplied, the step count is the length of the resulting
        ``scheduler.timesteps``; otherwise it is the value passed in.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are provided.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is None and sigmas is None:
        # Plain step-count mode: the caller's count is returned untouched.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
205
+
206
+ # Styles for flux.1-dev-realism
207
  style_list = [
208
+ {"name": "3840 x 2160", "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic", "negative_prompt": ""},
209
+ {"name": "2560 x 1440", "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic", "negative_prompt": ""},
210
+ {"name": "HD+", "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic", "negative_prompt": ""},
211
+ {"name": "Style Zero", "prompt": "{prompt}", "negative_prompt": ""},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  ]
213
 
214
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
 
219
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
220
  return p.replace("{prompt}", positive), n
221
 
222
+ # Generation function for flux.1-dev-realism
223
  @spaces.GPU
224
+ def generate_dev(
225
  prompt: str,
226
  negative_prompt: str = "",
227
  use_negative_prompt: bool = False,
 
253
 
254
  start_time = time.time()
255
 
256
+ images = pipe_dev(
257
  prompt=positive_prompt,
258
  negative_prompt=final_negative_prompt if final_negative_prompt else None,
259
  width=width,
 
280
 
281
  return image_paths, seed, f"{duration:.2f}", zip_path
282
 
283
+ # Generation function for flux.1-krea
284
@spaces.GPU
def generate_krea(
    prompt: str,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.5,
    randomize_seed: bool = False,
    num_inference_steps: int = 28,
    num_images: int = 1,
    zip_images: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images with the flux.1-krea pipeline and save them to disk.

    Args:
        prompt: Text prompt passed to the pipeline unchanged.
        seed: Seed for the torch generator; replaced by a random value in
            ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        randomize_seed: Whether to draw a fresh random seed.
        num_inference_steps: Number of denoising steps.
        num_images: How many images to produce. They share one generator, so
            each run continues the same random stream.
        zip_images: When true, also bundle the saved PNGs into a zip archive.
        progress: Gradio progress tracker (not referenced in the body).

    Returns:
        ``(image_paths, seed, duration, zip_path)`` where ``duration`` is the
        elapsed generation time formatted with two decimals and ``zip_path``
        is ``None`` unless ``zip_images`` was set.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed needs a real int; UI widgets may hand over a float.
    generator = torch.Generator().manual_seed(int(seed))

    start_time = time.time()

    images = []
    for _ in range(num_images):
        # The pipeline yields a preview per step and the final image last.
        # Stream through it and keep only the most recent frame instead of
        # materialising every preview in a list.
        final_img = None
        for final_img in pipe_krea.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            output_type="pil",
            good_vae=good_vae,
        ):
            pass
        images.append(final_img)

    duration = time.time() - start_time

    image_paths = [save_image(img) for img in images]

    zip_path = None
    if zip_images:
        zip_name = str(uuid.uuid4()) + ".zip"
        with zipfile.ZipFile(zip_name, 'w') as zipf:
            for i, img_path in enumerate(image_paths):
                zipf.write(img_path, arcname=f"Img_{i}.png")
        zip_path = zip_name

    return image_paths, seed, f"{duration:.2f}", zip_path
331
+
332
+ # Main generation function to handle model choice
333
@spaces.GPU
def generate(
    model_choice: str,
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    style_name: str = DEFAULT_STYLE_NAME,
    num_inference_steps: int = 30,
    num_images: int = 1,
    zip_images: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Route a generation request to the backend chosen in the UI.

    ``"flux.1-dev-realism"`` is served by ``generate_dev`` and receives every
    option; ``"flux.1-krea"`` is served by ``generate_krea``, which does not
    take the negative-prompt or style options.

    Returns:
        Whatever the selected backend returns.

    Raises:
        ValueError: If ``model_choice`` is neither of the two known names.
    """
    # Options that both backends understand.
    common_options = dict(
        prompt=prompt,
        seed=seed,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        randomize_seed=randomize_seed,
        num_inference_steps=num_inference_steps,
        num_images=num_images,
        zip_images=zip_images,
        progress=progress,
    )

    if model_choice == "flux.1-dev-realism":
        return generate_dev(
            negative_prompt=negative_prompt,
            use_negative_prompt=use_negative_prompt,
            style_name=style_name,
            **common_options,
        )

    if model_choice == "flux.1-krea":
        return generate_krea(**common_options)

    raise ValueError("Invalid model choice")
381
+
382
+ # Examples (tailored for flux.1-dev-realism)
383
  examples = [
384
  "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
385
  "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
386
  "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
387
+ "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights."
388
  ]
389
 
390
  css = '''
 
400
  }
401
  '''
402
 
403
+ # Gradio interface
404
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
405
  gr.Markdown(DESCRIPTION)
406
  with gr.Row():
 
413
  )
414
  run_button = gr.Button("Run", scale=0, variant="primary")
415
  result = gr.Gallery(label="Result", columns=1, show_label=False, preview=True)
416
+
417
+ # Model choice radio button above additional options
418
+ model_choice = gr.Radio(
419
+ choices=["flux.1-krea", "flux.1-dev-realism"],
420
+ label="Select Model",
421
+ value="flux.1-krea"
422
+ )
423
+
424
  with gr.Accordion("Additional Options", open=False):
425
  style_selection = gr.Dropdown(
426
+ label="Quality Style (for flux.1-dev-realism only)",
427
  choices=STYLE_NAMES,
428
  value=DEFAULT_STYLE_NAME,
429
  interactive=True,
430
  )
431
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt (for flux.1-dev-realism only)", value=False)
432
  negative_prompt = gr.Text(
433
  label="Negative prompt",
434
  max_lines=1,
 
508
  ],
509
  fn=generate,
510
  inputs=[
511
+ model_choice,
512
  prompt,
513
  negative_prompt,
514
  use_negative_prompt,