elismasilva commited on
Commit
19688ef
Β·
1 Parent(s): e2482a3

update gradio version && improvements

Browse files
app.py β†’ app_mixture.py RENAMED
@@ -1,102 +1,99 @@
1
  import random
 
2
  import gradio as gr
3
  import numpy as np
4
  import spaces
5
  import torch
 
 
 
6
  from diffusers import AutoencoderKL
7
- from mixture_tiling_sdxl import StableDiffusionXLTilingPipeline
8
 
9
  MAX_SEED = np.iinfo(np.int32).max
10
- SCHEDULERS = [
11
- "LMSDiscreteScheduler",
12
- "DEISMultistepScheduler",
13
- "HeunDiscreteScheduler",
14
- "EulerAncestralDiscreteScheduler",
15
- "EulerDiscreteScheduler",
16
- "DPMSolverMultistepScheduler",
17
- "DPMSolverMultistepScheduler-Karras",
18
- "DPMSolverMultistepScheduler-Karras-SDE",
19
- "UniPCMultistepScheduler"
20
- ]
21
-
22
- vae = AutoencoderKL.from_pretrained(
23
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
24
- ).to("cuda")
25
 
26
- model_id="stablediffusionapi/yamermix-v8-vae"
 
 
27
  pipe = StableDiffusionXLTilingPipeline.from_pretrained(
28
  model_id,
29
  torch_dtype=torch.float16,
30
  vae=vae,
31
- use_safetensors=False, #for yammermix
32
- #variant="fp16",
33
  ).to("cuda")
34
 
35
- pipe.enable_model_cpu_offload() #<< Enable this if you have limited VRAM
36
  pipe.enable_vae_tiling()
37
  pipe.enable_vae_slicing()
38
 
39
- #region functions
40
- def select_scheduler(scheduler_name):
41
- scheduler = scheduler_name.split("-")
42
- scheduler_class_name = scheduler[0]
43
- add_kwargs = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear", "num_train_timesteps": 1000}
44
- if len(scheduler) > 1:
45
- add_kwargs["use_karras_sigmas"] = True
46
- if len(scheduler) > 2:
47
- add_kwargs["algorithm_type"] = "sde-dpmsolver++"
48
- import diffusers
49
- scheduler = getattr(diffusers, scheduler_class_name)
50
- scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
51
- return scheduler
52
-
53
-
54
 
 
55
  @spaces.GPU
56
- def predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  global pipe
58
-
59
  # Set selected scheduler
60
  print(f"Using scheduler: {scheduler}...")
61
- pipe.scheduler = select_scheduler(scheduler)
62
 
63
  # Set seed
64
  generator = torch.Generator("cuda").manual_seed(generation_seed)
65
-
66
  target_height = int(target_height)
67
  target_width = int(target_width)
68
  tile_height = int(tile_height)
69
  tile_width = int(tile_width)
70
-
71
  # Mixture of Diffusers generation
72
  image = pipe(
73
  prompt=[
74
  [
75
  left_prompt,
76
  center_prompt,
77
- right_prompt,
78
  ]
79
  ],
80
  negative_prompt=negative_prompt,
81
  tile_height=tile_height,
82
  tile_width=tile_width,
83
  tile_row_overlap=0,
84
- tile_col_overlap=overlap_pixels,
85
- guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
86
  height=target_height,
87
- width=target_width,
88
  generator=generator,
89
  num_inference_steps=steps,
90
  )["images"][0]
91
 
 
92
  return image
93
 
 
94
  def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280):
95
- num_cols=3
96
- num_rows=1
97
- min_tile_dimension=8
98
- reduction_step=8
99
- max_tile_height_size=1024
100
  best_tile_width = 0
101
  best_tile_height = 0
102
  best_adjusted_target_width = 0
@@ -109,11 +106,11 @@ def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_s
109
 
110
  while tile_width >= min_tile_dimension:
111
  horizontal_borders = num_cols - 1
112
- total_horizontal_overlap_pixels = (overlap_pixels * horizontal_borders)
113
  adjusted_target_width = tile_width * num_cols - total_horizontal_overlap_pixels
114
 
115
  vertical_borders = num_rows - 1
116
- total_vertical_overlap_pixels = (overlap_pixels * vertical_borders)
117
  adjusted_target_height = tile_height * num_rows - total_vertical_overlap_pixels
118
 
119
  if tile_width <= max_tile_width_size and adjusted_target_width <= target_width:
@@ -131,15 +128,15 @@ def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_s
131
 
132
  while tile_height >= min_tile_dimension:
133
  horizontal_borders = num_cols - 1
134
- total_horizontal_overlap_pixels = (overlap_pixels * horizontal_borders)
135
  adjusted_target_width = tile_width * num_cols - total_horizontal_overlap_pixels
136
 
137
  vertical_borders = num_rows - 1
138
- total_vertical_overlap_pixels = (overlap_pixels * vertical_borders)
139
  adjusted_target_height = tile_height * num_rows - total_vertical_overlap_pixels
140
-
141
  if tile_height <= max_tile_height_size and adjusted_target_height <= target_height:
142
- if adjusted_target_height > best_adjusted_target_height:
143
  best_tile_height = tile_height
144
  best_adjusted_target_height = adjusted_target_height
145
 
@@ -150,7 +147,7 @@ def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_s
150
  tile_width = best_tile_width
151
  tile_height = best_tile_height
152
 
153
- print("--- TILE SIZE CALCULATED VALUES ---")
154
  print(f"Overlap pixels (requested): {overlap_pixels}")
155
  print(f"Tile Height (divisible by 8, max {max_tile_height_size}): {tile_height}")
156
  print(f"Tile Width (divisible by 8, max {max_tile_width_size}): {tile_width}")
@@ -163,32 +160,122 @@ def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_s
163
 
164
  return new_target_height, new_target_width, tile_height, tile_width
165
 
166
- def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size):
167
- new_target_height, new_target_width, tile_height, tile_width = calc_tile_size(target_height, target_width, overlap_pixels, max_tile_size)
168
- return gr.update(value=tile_height), gr.update(value=tile_width), gr.update(value=new_target_height), gr.update(value=new_target_width)
 
 
 
 
 
 
 
 
 
169
 
170
  def clear_result():
171
  return gr.update(value=None)
172
 
173
- def run_for_examples(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width, max_tile_width):
174
- return predict(left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, generation_seed, scheduler, tile_height, tile_width, target_height, target_width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
177
  if randomize_seed:
178
  generation_seed = random.randint(0, MAX_SEED)
179
  return generation_seed
180
 
 
181
  css = """
182
- .gradio-container .fillable {
183
- width: 95% !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  max-width: unset !important;
185
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  """
187
- title = """<h1 align="center">Mixture-of-Diffusers for SDXL Tiling PipelineπŸ€—</h1>
188
  <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; text-align: center; overflow:hidden;">
189
- <span>This <a href="https://github.com/DEVAIEXP/mixture-of-diffusers-sdxl-tiling">project</a> implements a SDXL tiling pipeline based on the original project: <a href='https://github.com/albarji/mixture-of-diffusers'>Mixture-of-Diffusers</a>. For more information, see the:
190
  <a href="https://arxiv.org/pdf/2302.02412">πŸ“œ paper </a>
191
- </div>
192
  """
193
 
194
  tips = """
@@ -212,102 +299,67 @@ about = """
212
  If you have any questions or suggestions, feel free to send your question to <b>[email protected]</b>.
213
  """
214
 
215
- with gr.Blocks(css=css) as app:
216
- gr.Markdown(title)
217
  with gr.Row():
218
  with gr.Column(scale=7):
219
  generate_button = gr.Button("Generate")
220
  with gr.Row():
221
  with gr.Column(scale=1):
222
  gr.Markdown("### Left region")
223
- left_prompt = gr.Textbox(lines=4,
224
- label="Prompt for left side of the image")
225
- left_gs = gr.Slider(minimum=0,
226
- maximum=15,
227
- value=7,
228
- step=1,
229
- label="Left CFG scale")
230
  with gr.Column(scale=1):
231
  gr.Markdown("### Center region")
232
- center_prompt = gr.Textbox(lines=4,
233
- label="Prompt for the center of the image")
234
- center_gs = gr.Slider(minimum=0,
235
- maximum=15,
236
- value=7,
237
- step=1,
238
- label="Center CFG scale")
239
  with gr.Column(scale=1):
240
  gr.Markdown("### Right region")
241
- right_prompt = gr.Textbox(lines=4,
242
- label="Prompt for the right side of the image")
243
- right_gs = gr.Slider(minimum=0,
244
- maximum=15,
245
- value=7,
246
- step=1,
247
- label="Right CFG scale")
248
  with gr.Row():
249
- negative_prompt = gr.Textbox(lines=2,
250
- label="Negative prompt for the image",
251
- value="nsfw, lowres, bad anatomy, bad hands, duplicate, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry")
 
 
252
  with gr.Row():
253
  result = gr.Image(
254
  label="Generated Image",
255
- show_label=True,
256
  format="png",
257
  interactive=False,
258
  # allow_preview=True,
259
  # preview=True,
260
  scale=1,
261
-
262
  )
263
  with gr.Column():
264
  gr.Markdown(tips)
265
  with gr.Sidebar(label="Parameters", open=True):
266
  gr.Markdown("### General parameters")
267
  with gr.Row():
268
- height = gr.Slider(label="Height",
269
- value=1024,
270
- step=8,
271
- visible=True,
272
- minimum=512,
273
- maximum=1024)
274
- width = gr.Slider(label="Width",
275
- value=1280,
276
- step=8,
277
- visible=True,
278
- minimum=512,
279
- maximum=3840)
280
- overlap = gr.Slider(minimum=0,
281
- maximum=512,
282
- value=128,
283
- step=8,
284
- label="Tile Overlap")
285
  max_tile_size = gr.Dropdown(label="Max. Tile Size", choices=[1024, 1280], value=1280)
286
- calc_tile = gr.Button("Calculate Tile Size")
287
- with gr.Row():
288
- tile_height = gr.Textbox(label="Tile height", value=1024, interactive=False)
289
  tile_width = gr.Textbox(label="Tile width", value=1024, interactive=False)
290
  with gr.Row():
291
  new_target_height = gr.Textbox(label="New image height", value=1024, interactive=False)
292
  new_target_width = gr.Textbox(label="New image width", value=1024, interactive=False)
293
  with gr.Row():
294
- steps = gr.Slider(minimum=1,
295
- maximum=50,
296
- value=30,
297
- step=1,
298
- label="Inference steps")
299
-
300
- generation_seed = gr.Slider(label="Seed",
301
- minimum=0,
302
- maximum=MAX_SEED,
303
- step=1,
304
- value=0)
305
- randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
306
  with gr.Row():
 
307
  scheduler = gr.Dropdown(
308
- label="Schedulers",
309
- choices=SCHEDULERS,
310
- value=SCHEDULERS[0],
311
  )
312
  with gr.Row():
313
  gr.Examples(
@@ -317,81 +369,114 @@ with gr.Blocks(css=css) as app:
317
  "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, collapsing buildings, rubble streets, battle-damaged suit, determined expression, distant explosions, cinematic composition, realistic rendering. Focus: Captain America.",
318
  "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, shattered buildings, burning debris, ground trembling, Asgardian armor, cinematic photography, realistic details. Focus: Thor.",
319
  negative_prompt.value,
320
- 5, 5, 5,
 
 
321
  160,
322
  30,
323
  619517442,
324
- "UniPCMultistepScheduler",
325
  1024,
326
  1280,
327
- 1024,
328
  3840,
329
- 1024
 
330
  ],
331
  [
332
  "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
333
  "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
334
  "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
335
  negative_prompt.value,
336
- 7, 7, 7,
 
 
337
  256,
338
  30,
339
  358867853,
340
- "DPMSolverMultistepScheduler-Karras-SDE",
341
  1024,
342
  1280,
343
- 1024,
344
  3840,
345
- 1280
 
346
  ],
347
  [
348
  "Abstract decorative illustration, by joan miro and gustav klimt and marlina vera and loish, elegant, intricate, highly detailed, smooth, sharp focus, vibrant colors, artstation, stunning masterpiece",
349
  "Abstract decorative illustration, by joan miro and gustav klimt and marlina vera and loish, elegant, intricate, highly detailed, smooth, sharp focus, vibrant colors, artstation, stunning masterpiece",
350
  "Abstract decorative illustration, by joan miro and gustav klimt and marlina vera and loish, elegant, intricate, highly detailed, smooth, sharp focus, vibrant colors, artstation, stunning masterpiece",
351
  negative_prompt.value,
352
- 7, 7, 7,
 
 
353
  128,
354
  30,
355
  580541206,
356
- "LMSDiscreteScheduler",
357
  1024,
358
  768,
359
- 1024,
360
  2048,
361
- 1280
 
362
  ],
363
  [
364
  "Magical diagrams and runes written with chalk on a blackboard, elegant, intricate, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
365
  "Magical diagrams and runes written with chalk on a blackboard, elegant, intricate, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
366
  "Magical diagrams and runes written with chalk on a blackboard, elegant, intricate, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
367
  negative_prompt.value,
368
- 9, 9, 9,
 
 
369
  128,
370
  30,
371
  12591765619,
372
- "LMSDiscreteScheduler",
373
  1024,
374
  768,
375
- 1024,
376
  2048,
377
- 1280
378
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  ],
380
- inputs=[left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap, steps, generation_seed, scheduler, tile_height, tile_width, height, width, max_tile_size],
381
  fn=run_for_examples,
382
  outputs=result,
383
- cache_examples=True
384
  )
385
-
386
- event_calc_tile_size={"fn": do_calc_tile, "inputs":[height, width, overlap, max_tile_size], "outputs":[tile_height, tile_width, new_target_height, new_target_width]}
 
 
 
 
387
  calc_tile.click(**event_calc_tile_size)
388
-
389
  generate_button.click(
390
  fn=clear_result,
391
  inputs=None,
392
  outputs=result,
393
- ).then(**event_calc_tile_size
394
- ).then(
395
  fn=randomize_seed_fn,
396
  inputs=[generation_seed, randomize_seed],
397
  outputs=generation_seed,
@@ -399,7 +484,24 @@ with gr.Blocks(css=css) as app:
399
  api_name=False,
400
  ).then(
401
  fn=predict,
402
- inputs=[left_prompt, center_prompt, right_prompt, negative_prompt, left_gs, center_gs, right_gs, overlap, steps, generation_seed, scheduler, tile_height, tile_width, new_target_height, new_target_width],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  outputs=result,
404
  )
405
  gr.Markdown(about)
 
1
  import random
2
+
3
  import gradio as gr
4
  import numpy as np
5
  import spaces
6
  import torch
7
+ from pipeline.mixture_tiling_sdxl import StableDiffusionXLTilingPipeline
8
+ from pipeline.util import SAMPLERS, create_hdr_effect, select_scheduler
9
+
10
  from diffusers import AutoencoderKL
11
+
12
 
13
  MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
16
+
17
+ model_id = "stablediffusionapi/yamermix-v8-vae"
18
  pipe = StableDiffusionXLTilingPipeline.from_pretrained(
19
  model_id,
20
  torch_dtype=torch.float16,
21
  vae=vae,
22
+ use_safetensors=False, # for yammermix
23
+ # variant="fp16",
24
  ).to("cuda")
25
 
26
+ #pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
27
  pipe.enable_vae_tiling()
28
  pipe.enable_vae_slicing()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # region functions
32
  @spaces.GPU
33
+ def predict(
34
+ left_prompt,
35
+ center_prompt,
36
+ right_prompt,
37
+ negative_prompt,
38
+ left_gs,
39
+ center_gs,
40
+ right_gs,
41
+ overlap_pixels,
42
+ steps,
43
+ generation_seed,
44
+ scheduler,
45
+ tile_height,
46
+ tile_width,
47
+ target_height,
48
+ target_width,
49
+ hdr,
50
+ progress=gr.Progress(track_tqdm=True),
51
+ ):
52
  global pipe
53
+
54
  # Set selected scheduler
55
  print(f"Using scheduler: {scheduler}...")
56
+ pipe.scheduler = select_scheduler(pipe, scheduler)
57
 
58
  # Set seed
59
  generator = torch.Generator("cuda").manual_seed(generation_seed)
60
+
61
  target_height = int(target_height)
62
  target_width = int(target_width)
63
  tile_height = int(tile_height)
64
  tile_width = int(tile_width)
65
+
66
  # Mixture of Diffusers generation
67
  image = pipe(
68
  prompt=[
69
  [
70
  left_prompt,
71
  center_prompt,
72
+ right_prompt,
73
  ]
74
  ],
75
  negative_prompt=negative_prompt,
76
  tile_height=tile_height,
77
  tile_width=tile_width,
78
  tile_row_overlap=0,
79
+ tile_col_overlap=overlap_pixels,
80
+ guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
81
  height=target_height,
82
+ width=target_width,
83
  generator=generator,
84
  num_inference_steps=steps,
85
  )["images"][0]
86
 
87
+ image = create_hdr_effect(image, hdr)
88
  return image
89
 
90
+
91
  def calc_tile_size(target_height, target_width, overlap_pixels, max_tile_width_size=1280):
92
+ num_cols = 3
93
+ num_rows = 1
94
+ min_tile_dimension = 8
95
+ reduction_step = 8
96
+ max_tile_height_size = 1024
97
  best_tile_width = 0
98
  best_tile_height = 0
99
  best_adjusted_target_width = 0
 
106
 
107
  while tile_width >= min_tile_dimension:
108
  horizontal_borders = num_cols - 1
109
+ total_horizontal_overlap_pixels = overlap_pixels * horizontal_borders
110
  adjusted_target_width = tile_width * num_cols - total_horizontal_overlap_pixels
111
 
112
  vertical_borders = num_rows - 1
113
+ total_vertical_overlap_pixels = overlap_pixels * vertical_borders
114
  adjusted_target_height = tile_height * num_rows - total_vertical_overlap_pixels
115
 
116
  if tile_width <= max_tile_width_size and adjusted_target_width <= target_width:
 
128
 
129
  while tile_height >= min_tile_dimension:
130
  horizontal_borders = num_cols - 1
131
+ total_horizontal_overlap_pixels = overlap_pixels * horizontal_borders
132
  adjusted_target_width = tile_width * num_cols - total_horizontal_overlap_pixels
133
 
134
  vertical_borders = num_rows - 1
135
+ total_vertical_overlap_pixels = overlap_pixels * vertical_borders
136
  adjusted_target_height = tile_height * num_rows - total_vertical_overlap_pixels
137
+
138
  if tile_height <= max_tile_height_size and adjusted_target_height <= target_height:
139
+ if adjusted_target_height > best_adjusted_target_height:
140
  best_tile_height = tile_height
141
  best_adjusted_target_height = adjusted_target_height
142
 
 
147
  tile_width = best_tile_width
148
  tile_height = best_tile_height
149
 
150
+ print("--- TILE SIZE CALCULATED VALUES ---")
151
  print(f"Overlap pixels (requested): {overlap_pixels}")
152
  print(f"Tile Height (divisible by 8, max {max_tile_height_size}): {tile_height}")
153
  print(f"Tile Width (divisible by 8, max {max_tile_width_size}): {tile_width}")
 
160
 
161
  return new_target_height, new_target_width, tile_height, tile_width
162
 
163
+
164
+ def do_calc_tile(target_height, target_width, overlap_pixels, max_tile_size):
165
+ new_target_height, new_target_width, tile_height, tile_width = calc_tile_size(
166
+ target_height, target_width, overlap_pixels, max_tile_size
167
+ )
168
+ return (
169
+ gr.update(value=tile_height),
170
+ gr.update(value=tile_width),
171
+ gr.update(value=new_target_height),
172
+ gr.update(value=new_target_width),
173
+ )
174
+
175
 
176
  def clear_result():
177
  return gr.update(value=None)
178
 
179
+
180
+ def run_for_examples(
181
+ left_prompt,
182
+ center_prompt,
183
+ right_prompt,
184
+ negative_prompt,
185
+ left_gs,
186
+ center_gs,
187
+ right_gs,
188
+ overlap_pixels,
189
+ steps,
190
+ generation_seed,
191
+ scheduler,
192
+ tile_height,
193
+ tile_width,
194
+ target_height,
195
+ target_width,
196
+ max_tile_width,
197
+ hdr,
198
+ ):
199
+ return predict(
200
+ left_prompt,
201
+ center_prompt,
202
+ right_prompt,
203
+ negative_prompt,
204
+ left_gs,
205
+ center_gs,
206
+ right_gs,
207
+ overlap_pixels,
208
+ steps,
209
+ generation_seed,
210
+ scheduler,
211
+ tile_height,
212
+ tile_width,
213
+ target_height,
214
+ target_width,
215
+ hdr,
216
+ )
217
+
218
 
219
  def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
220
  if randomize_seed:
221
  generation_seed = random.randint(0, MAX_SEED)
222
  return generation_seed
223
 
224
+
225
  css = """
226
+ body {
227
+ font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
228
+ margin: 0;
229
+ padding: 0;
230
+ }
231
+ .gradio-container {
232
+ border-radius: 15px;
233
+ padding: 30px 40px;
234
+ box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
235
+ margin: 40px 340px;
236
+ }
237
+ .gradio-container h1 {
238
+ text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
239
+ }
240
+ .fillable {
241
+ width: 100% !important;
242
  max-width: unset !important;
243
  }
244
+ #examples_container {
245
+ margin: auto;
246
+ width: 90%;
247
+ }
248
+ #examples_row {
249
+ justify-content: center;
250
+ }
251
+ #tips_row{
252
+ padding-left: 20px;
253
+ }
254
+ .sidebar {
255
+ border-radius: 10px;
256
+ padding: 10px;
257
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
258
+ }
259
+ .sidebar .toggle-button {
260
+ background: linear-gradient(90deg, #fbbf24, #fcd34d) !important;
261
+ border: none;
262
+ padding: 12px 24px;
263
+ text-transform: uppercase;
264
+ font-weight: bold;
265
+ letter-spacing: 1px;
266
+ border-radius: 5px;
267
+ cursor: pointer;
268
+ transition: transform 0.2s ease-in-out;
269
+ }
270
+ .toggle-button:hover {
271
+ transform: scale(1.05);
272
+ }
273
  """
274
+ title = """<h1 align="center">Mixture-of-Diffusers for SDXL Tiling PipelineπŸ€—</h1>
275
  <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; text-align: center; overflow:hidden;">
276
+ <span>This <a href="https://github.com/DEVAIEXP/mixture-of-diffusers-sdxl-tiling">project</a> implements a SDXL tiling pipeline based on the original project: <a href='https://github.com/albarji/mixture-of-diffusers'>Mixture-of-Diffusers</a>. For more information, see the:
277
  <a href="https://arxiv.org/pdf/2302.02412">πŸ“œ paper </a>
278
+ </div>
279
  """
280
 
281
  tips = """
 
299
  If you have any questions or suggestions, feel free to send your question to <b>[email protected]</b>.
300
  """
301
 
302
+ with gr.Blocks(css=css, theme=gr.themes.Citrus()) as app:
303
+ gr.Markdown(title)
304
  with gr.Row():
305
  with gr.Column(scale=7):
306
  generate_button = gr.Button("Generate")
307
  with gr.Row():
308
  with gr.Column(scale=1):
309
  gr.Markdown("### Left region")
310
+ left_prompt = gr.Textbox(lines=4, label="Prompt for left side of the image")
311
+ left_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Left CFG scale")
 
 
 
 
 
312
  with gr.Column(scale=1):
313
  gr.Markdown("### Center region")
314
+ center_prompt = gr.Textbox(lines=4, label="Prompt for the center of the image")
315
+ center_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Center CFG scale")
 
 
 
 
 
316
  with gr.Column(scale=1):
317
  gr.Markdown("### Right region")
318
+ right_prompt = gr.Textbox(lines=4, label="Prompt for the right side of the image")
319
+ right_gs = gr.Slider(minimum=0, maximum=15, value=7, step=1, label="Right CFG scale")
 
 
 
 
 
320
  with gr.Row():
321
+ negative_prompt = gr.Textbox(
322
+ lines=2,
323
+ label="Negative prompt for the image",
324
+ value="nsfw, lowres, bad anatomy, bad hands, duplicate, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry",
325
+ )
326
  with gr.Row():
327
  result = gr.Image(
328
  label="Generated Image",
329
+ show_label=True,
330
  format="png",
331
  interactive=False,
332
  # allow_preview=True,
333
  # preview=True,
334
  scale=1,
 
335
  )
336
  with gr.Column():
337
  gr.Markdown(tips)
338
  with gr.Sidebar(label="Parameters", open=True):
339
  gr.Markdown("### General parameters")
340
  with gr.Row():
341
+ height = gr.Slider(label="Height", value=1024, step=8, visible=True, minimum=512, maximum=1024)
342
+ width = gr.Slider(label="Width", value=1280, step=8, visible=True, minimum=512, maximum=3840)
343
+ overlap = gr.Slider(minimum=0, maximum=512, value=128, step=8, label="Tile Overlap")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  max_tile_size = gr.Dropdown(label="Max. Tile Size", choices=[1024, 1280], value=1280)
345
+ calc_tile = gr.Button("Calculate Tile Size")
346
+ with gr.Row():
347
+ tile_height = gr.Textbox(label="Tile height", value=1024, interactive=False)
348
  tile_width = gr.Textbox(label="Tile width", value=1024, interactive=False)
349
  with gr.Row():
350
  new_target_height = gr.Textbox(label="New image height", value=1024, interactive=False)
351
  new_target_width = gr.Textbox(label="New image width", value=1024, interactive=False)
352
  with gr.Row():
353
+ steps = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference steps")
354
+
355
+ generation_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
356
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
 
 
 
 
 
 
 
 
357
  with gr.Row():
358
+ hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
359
  scheduler = gr.Dropdown(
360
+ label="Sampler",
361
+ choices=list(SAMPLERS.keys()),
362
+ value="UniPC",
363
  )
364
  with gr.Row():
365
  gr.Examples(
 
369
  "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, collapsing buildings, rubble streets, battle-damaged suit, determined expression, distant explosions, cinematic composition, realistic rendering. Focus: Captain America.",
370
  "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, shattered buildings, burning debris, ground trembling, Asgardian armor, cinematic photography, realistic details. Focus: Thor.",
371
  negative_prompt.value,
372
+ 5,
373
+ 5,
374
+ 5,
375
  160,
376
  30,
377
  619517442,
378
+ "UniPC",
379
  1024,
380
  1280,
381
+ 1024,
382
  3840,
383
+ 1024,
384
+ 0,
385
  ],
386
  [
387
  "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
388
  "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
389
  "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
390
  negative_prompt.value,
391
+ 7,
392
+ 7,
393
+ 7,
394
  256,
395
  30,
396
  358867853,
397
+ "DPM++ 3M Karras",
398
  1024,
399
  1280,
400
+ 1024,
401
  3840,
402
+ 1280,
403
+ 0,
404
  ],
405
  [
406
  "Abstract decorative illustration, by joan miro and gustav klimt and marlina vera and loish, elegant, intricate, highly detailed, smooth, sharp focus, vibrant colors, artstation, stunning masterpiece",
407
  "Abstract decorative illustration, by joan miro and gustav klimt and marlina vera and loish, elegant, intricate, highly detailed, smooth, sharp focus, vibrant colors, artstation, stunning masterpiece",
408
  "Abstract decorative illustration, by joan miro and gustav klimt and marlina vera and loish, elegant, intricate, highly detailed, smooth, sharp focus, vibrant colors, artstation, stunning masterpiece",
409
  negative_prompt.value,
410
+ 7,
411
+ 7,
412
+ 7,
413
  128,
414
  30,
415
  580541206,
416
+ "LMS",
417
  1024,
418
  768,
419
+ 1024,
420
  2048,
421
+ 1280,
422
+ 0,
423
  ],
424
  [
425
  "Magical diagrams and runes written with chalk on a blackboard, elegant, intricate, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
426
  "Magical diagrams and runes written with chalk on a blackboard, elegant, intricate, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
427
  "Magical diagrams and runes written with chalk on a blackboard, elegant, intricate, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
428
  negative_prompt.value,
429
+ 9,
430
+ 9,
431
+ 9,
432
  128,
433
  30,
434
  12591765619,
435
+ "LMS",
436
  1024,
437
  768,
438
+ 1024,
439
  2048,
440
+ 1280,
441
+ 0,
442
+ ],
443
+ ],
444
+ inputs=[
445
+ left_prompt,
446
+ center_prompt,
447
+ right_prompt,
448
+ negative_prompt,
449
+ left_gs,
450
+ center_gs,
451
+ right_gs,
452
+ overlap,
453
+ steps,
454
+ generation_seed,
455
+ scheduler,
456
+ tile_height,
457
+ tile_width,
458
+ height,
459
+ width,
460
+ max_tile_size,
461
+ hdr,
462
  ],
 
463
  fn=run_for_examples,
464
  outputs=result,
465
+ cache_examples=True,
466
  )
467
+
468
+ event_calc_tile_size = {
469
+ "fn": do_calc_tile,
470
+ "inputs": [height, width, overlap, max_tile_size],
471
+ "outputs": [tile_height, tile_width, new_target_height, new_target_width],
472
+ }
473
  calc_tile.click(**event_calc_tile_size)
474
+
475
  generate_button.click(
476
  fn=clear_result,
477
  inputs=None,
478
  outputs=result,
479
+ ).then(**event_calc_tile_size).then(
 
480
  fn=randomize_seed_fn,
481
  inputs=[generation_seed, randomize_seed],
482
  outputs=generation_seed,
 
484
  api_name=False,
485
  ).then(
486
  fn=predict,
487
+ inputs=[
488
+ left_prompt,
489
+ center_prompt,
490
+ right_prompt,
491
+ negative_prompt,
492
+ left_gs,
493
+ center_gs,
494
+ right_gs,
495
+ overlap,
496
+ steps,
497
+ generation_seed,
498
+ scheduler,
499
+ tile_height,
500
+ tile_width,
501
+ new_target_height,
502
+ new_target_width,
503
+ hdr,
504
+ ],
505
  outputs=result,
506
  )
507
  gr.Markdown(about)
mixture_tiling_sdxl.py β†’ pipeline/mixture_tiling_sdxl.py RENAMED
@@ -1,4 +1,4 @@
1
- # Copyright 2025 The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -1067,32 +1067,32 @@ class StableDiffusionXLTilingPipeline(
1067
  text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1068
  else:
1069
  text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1070
- add_time_ids = self._get_add_time_ids(
1071
- original_size,
1072
- crops_coords_top_left[row][col],
1073
- target_size,
 
 
 
 
 
 
 
 
1074
  dtype=prompt_embeds.dtype,
1075
  text_encoder_projection_dim=text_encoder_projection_dim,
1076
  )
1077
- if negative_original_size is not None and negative_target_size is not None:
1078
- negative_add_time_ids = self._get_add_time_ids(
1079
- negative_original_size,
1080
- negative_crops_coords_top_left[row][col],
1081
- negative_target_size,
1082
- dtype=prompt_embeds.dtype,
1083
- text_encoder_projection_dim=text_encoder_projection_dim,
1084
- )
1085
- else:
1086
- negative_add_time_ids = add_time_ids
1087
 
1088
- if self.do_classifier_free_guidance:
1089
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1090
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1091
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1092
 
1093
- prompt_embeds = prompt_embeds.to(device)
1094
- add_text_embeds = add_text_embeds.to(device)
1095
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1096
  addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
1097
  embeddings_and_added_time.append(addition_embed_type_row)
1098
 
 
1
+ # Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
1067
  text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1068
  else:
1069
  text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1070
+ add_time_ids = self._get_add_time_ids(
1071
+ original_size,
1072
+ crops_coords_top_left[row][col],
1073
+ target_size,
1074
+ dtype=prompt_embeds.dtype,
1075
+ text_encoder_projection_dim=text_encoder_projection_dim,
1076
+ )
1077
+ if negative_original_size is not None and negative_target_size is not None:
1078
+ negative_add_time_ids = self._get_add_time_ids(
1079
+ negative_original_size,
1080
+ negative_crops_coords_top_left[row][col],
1081
+ negative_target_size,
1082
  dtype=prompt_embeds.dtype,
1083
  text_encoder_projection_dim=text_encoder_projection_dim,
1084
  )
1085
+ else:
1086
+ negative_add_time_ids = add_time_ids
 
 
 
 
 
 
 
 
1087
 
1088
+ if self.do_classifier_free_guidance:
1089
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1090
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1091
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1092
 
1093
+ prompt_embeds = prompt_embeds.to(device)
1094
+ add_text_embeds = add_text_embeds.to(device)
1095
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1096
  addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
1097
  embeddings_and_added_time.append(addition_embed_type_row)
1098
 
pipeline/util.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The DEVAIEXP Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import gc
17
+ import cv2
18
+ import numpy as np
19
+ import torch
20
+ from PIL import Image
21
+
22
+
23
+ MAX_SEED = np.iinfo(np.int32).max
24
+ SAMPLERS = {
25
+ "DDIM": ("DDIMScheduler", {}),
26
+ "DDIM trailing": ("DDIMScheduler", {"timestep_spacing": "trailing"}),
27
+ "DDPM": ("DDPMScheduler", {}),
28
+ "DEIS": ("DEISMultistepScheduler", {}),
29
+ "Heun": ("HeunDiscreteScheduler", {}),
30
+ "Heun Karras": ("HeunDiscreteScheduler", {"use_karras_sigmas": True}),
31
+ "Euler": ("EulerDiscreteScheduler", {}),
32
+ "Euler trailing": ("EulerDiscreteScheduler", {"timestep_spacing": "trailing", "prediction_type": "sample"}),
33
+ "Euler Ancestral": ("EulerAncestralDiscreteScheduler", {}),
34
+ "Euler Ancestral trailing": ("EulerAncestralDiscreteScheduler", {"timestep_spacing": "trailing"}),
35
+ "DPM++ 1S": ("DPMSolverMultistepScheduler", {"solver_order": 1}),
36
+ "DPM++ 1S Karras": ("DPMSolverMultistepScheduler", {"solver_order": 1, "use_karras_sigmas": True}),
37
+ "DPM++ 2S": ("DPMSolverSinglestepScheduler", {"use_karras_sigmas": False}),
38
+ "DPM++ 2S Karras": ("DPMSolverSinglestepScheduler", {"use_karras_sigmas": True}),
39
+ "DPM++ 2M": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": False}),
40
+ "DPM++ 2M Karras": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": True}),
41
+ "DPM++ 2M SDE": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
42
+ "DPM++ 2M SDE Karras": (
43
+ "DPMSolverMultistepScheduler",
44
+ {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"},
45
+ ),
46
+ "DPM++ 3M": ("DPMSolverMultistepScheduler", {"solver_order": 3}),
47
+ "DPM++ 3M Karras": ("DPMSolverMultistepScheduler", {"solver_order": 3, "use_karras_sigmas": True}),
48
+ "DPM++ SDE": ("DPMSolverSDEScheduler", {"use_karras_sigmas": False}),
49
+ "DPM++ SDE Karras": ("DPMSolverSDEScheduler", {"use_karras_sigmas": True}),
50
+ "DPM2": ("KDPM2DiscreteScheduler", {}),
51
+ "DPM2 Karras": ("KDPM2DiscreteScheduler", {"use_karras_sigmas": True}),
52
+ "DPM2 Ancestral": ("KDPM2AncestralDiscreteScheduler", {}),
53
+ "DPM2 Ancestral Karras": ("KDPM2AncestralDiscreteScheduler", {"use_karras_sigmas": True}),
54
+ "LMS": ("LMSDiscreteScheduler", {}),
55
+ "LMS Karras": ("LMSDiscreteScheduler", {"use_karras_sigmas": True}),
56
+ "UniPC": ("UniPCMultistepScheduler", {}),
57
+ "UniPC Karras": ("UniPCMultistepScheduler", {"use_karras_sigmas": True}),
58
+ "PNDM": ("PNDMScheduler", {}),
59
+ "Euler EDM": ("EDMEulerScheduler", {}),
60
+ "Euler EDM Karras": ("EDMEulerScheduler", {"use_karras_sigmas": True}),
61
+ "DPM++ 2M EDM": (
62
+ "EDMDPMSolverMultistepScheduler",
63
+ {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"},
64
+ ),
65
+ "DPM++ 2M EDM Karras": (
66
+ "EDMDPMSolverMultistepScheduler",
67
+ {
68
+ "use_karras_sigmas": True,
69
+ "solver_order": 2,
70
+ "solver_type": "midpoint",
71
+ "final_sigmas_type": "zero",
72
+ "algorithm_type": "dpmsolver++",
73
+ },
74
+ ),
75
+ "DPM++ 2M Lu": ("DPMSolverMultistepScheduler", {"use_lu_lambdas": True}),
76
+ "DPM++ 2M Ef": ("DPMSolverMultistepScheduler", {"euler_at_final": True}),
77
+ "DPM++ 2M SDE Lu": ("DPMSolverMultistepScheduler", {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
78
+ "DPM++ 2M SDE Ef": ("DPMSolverMultistepScheduler", {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
79
+ "LCM": ("LCMScheduler", {}),
80
+ "LCM trailing": ("LCMScheduler", {"timestep_spacing": "trailing"}),
81
+ "TCD": ("TCDScheduler", {}),
82
+ "TCD trailing": ("TCDScheduler", {"timestep_spacing": "trailing"}),
83
+ }
84
+
85
+ def select_scheduler(pipe, selected_sampler):
86
+ import diffusers
87
+
88
+ scheduler_class_name, add_kwargs = SAMPLERS[selected_sampler]
89
+ config = pipe.scheduler.config
90
+ scheduler = getattr(diffusers, scheduler_class_name)
91
+ if selected_sampler in ("LCM", "LCM trailing"):
92
+ config = {
93
+ x: config[x] for x in config if x not in ("skip_prk_steps", "interpolation_type", "use_karras_sigmas")
94
+ }
95
+ elif selected_sampler in ("TCD", "TCD trailing"):
96
+ config = {x: config[x] for x in config if x not in ("skip_prk_steps")}
97
+
98
+ return scheduler.from_config(config, **add_kwargs)
99
+
100
+
101
+ # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
102
+ def create_hdr_effect(original_image, hdr):
103
+ """
104
+ Applies an HDR (High Dynamic Range) effect to an image based on the specified intensity.
105
+
106
+ Args:
107
+ original_image (PIL.Image.Image): The original image to which the HDR effect will be applied.
108
+ hdr (float): The intensity of the HDR effect, ranging from 0 (no effect) to 1 (maximum effect).
109
+
110
+ Returns:
111
+ PIL.Image.Image: The image with the HDR effect applied.
112
+ """
113
+ if hdr == 0:
114
+ return original_image # No effect applied if hdr is 0
115
+
116
+ # Convert the PIL image to a NumPy array in BGR format (OpenCV format)
117
+ cv_original = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
118
+
119
+ # Define scaling factors for creating multiple exposures
120
+ factors = [
121
+ 1.0 - 0.9 * hdr,
122
+ 1.0 - 0.7 * hdr,
123
+ 1.0 - 0.45 * hdr,
124
+ 1.0 - 0.25 * hdr,
125
+ 1.0,
126
+ 1.0 + 0.2 * hdr,
127
+ 1.0 + 0.4 * hdr,
128
+ 1.0 + 0.6 * hdr,
129
+ 1.0 + 0.8 * hdr,
130
+ ]
131
+
132
+ # Generate multiple exposure images by scaling the original image
133
+ images = [cv2.convertScaleAbs(cv_original, alpha=factor) for factor in factors]
134
+
135
+ # Merge the images using the Mertens algorithm to create an HDR effect
136
+ merge_mertens = cv2.createMergeMertens()
137
+ hdr_image = merge_mertens.process(images)
138
+
139
+ # Convert the HDR image to 8-bit format (0-255 range)
140
+ hdr_image_8bit = np.clip(hdr_image * 255, 0, 255).astype("uint8")
141
+
142
+ torch_gc()
143
+
144
+ # Convert the image back to RGB format and return as a PIL image
145
+ return Image.fromarray(cv2.cvtColor(hdr_image_8bit, cv2.COLOR_BGR2RGB))
146
+
147
+
148
+ def torch_gc():
149
+ gc.collect()
150
+ if torch.cuda.is_available():
151
+ with torch.cuda.device("cuda"):
152
+ torch.cuda.empty_cache()
153
+ torch.cuda.ipc_collect()
154
+
155
+ def quantize_8bit(unet):
156
+ if unet is None:
157
+ return
158
+
159
+ from peft.tuners.tuners_utils import BaseTunerLayer
160
+
161
+ dtype = unet.dtype
162
+ unet.to(torch.float8_e4m3fn)
163
+ for module in unet.modules(): # revert lora modules to prevent errors with fp8
164
+ if isinstance(module, BaseTunerLayer):
165
+ module.to(dtype)
166
+
167
+ if hasattr(unet, "encoder_hid_proj"): # revert ip adapter modules to prevent errors with fp8
168
+ if unet.encoder_hid_proj is not None:
169
+ for module in unet.encoder_hid_proj.modules():
170
+ module.to(dtype)
171
+ torch_gc()
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
  torch
 
2
  spaces
3
  scipy
4
- gradio==5.15.0
 
5
  numpy==1.26.4
6
  transformers
7
  accelerate
 
1
  torch
2
+ peft
3
  spaces
4
  scipy
5
+ gradio==5.20.1
6
+ opencv-python
7
  numpy==1.26.4
8
  transformers
9
  accelerate