Freak-ppa committed
Commit 9754000 · verified · 1 Parent(s): 1c5a965

Update app.py

Files changed (1)
  1. app.py +22 -21
app.py CHANGED
@@ -228,16 +228,25 @@ def run_rmbg(img, sigma=0.0):
     return result.clip(0, 255).astype(np.uint8), alpha
 
 @torch.inference_mode()
-def merge_alpha(img):
+def merge_alpha(img, sigma=0.0):
+    if img is None:
+        return None
+
+    if len(img.shape) == 2:
+        img = np.stack((img,)*3, axis=-1)
+
     H, W, C = img.shape
     print(f"img.shape: {img.shape}")
+
     if C == 3:
         return img
-    else:
+    elif C == 4:
         rgb = img[:, :, :3]
-        alpha = img[:, :, 3] / 255.0
-        result = 127 + (rgb.astype(np.float32) - 127 + sigma) * alpha[:, :, np.newaxis]
+        alpha = img[:, :, 3]
+        result = 127 + (rgb.astype(np.float32) - 127 + sigma) * alpha[:, :, np.newaxis]
         return np.clip(result, 0, 255).astype(np.uint8)
+    else:
+        raise ValueError(f"Unexpected number of channels: {C}")
 
 
 @torch.inference_mode()
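
Note on the new merge_alpha: it flattens RGBA by blending the RGB channels toward mid-gray (127) in proportion to alpha, so transparent pixels come out gray and opaque pixels keep their color. The commit also drops the old / 255.0 normalization, leaving alpha in raw 0-255 units. A minimal sketch of the blend, assuming the normalized alpha that the removed line used:

import numpy as np

def merge_alpha_sketch(img, sigma=0.0):
    # Mirrors the committed control flow; only the alpha scaling is an assumption.
    if img is None:
        return None
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)  # grayscale -> 3-channel
    C = img.shape[2]
    if C == 3:
        return img  # already opaque RGB
    if C == 4:
        rgb = img[:, :, :3].astype(np.float32)
        alpha = img[:, :, 3:4].astype(np.float32) / 255.0  # assumption: normalize to [0, 1]
        result = 127 + (rgb - 127 + sigma) * alpha  # alpha=0 -> gray, alpha=1 -> original color
        return np.clip(result, 0, 255).astype(np.uint8)
    raise ValueError(f"Unexpected number of channels: {C}")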
@@ -268,6 +277,7 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
 
     rng = torch.Generator(device=device).manual_seed(int(seed))
 
+    #fg = input_fg
     fg = resize_and_center_crop(input_fg, image_width, image_height)
 
     concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
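
The foreground still goes through resize_and_center_crop (the commented-out fg = input_fg would bypass it). That helper is not part of this diff; a plausible sketch of its behavior, assuming Pillow: scale the image so it covers the target size, then crop the center:

import numpy as np
from PIL import Image

def resize_and_center_crop_sketch(image, target_width, target_height):
    pil = Image.fromarray(image)
    scale = max(target_width / pil.width, target_height / pil.height)  # cover, not fit
    pil = pil.resize((round(pil.width * scale), round(pil.height * scale)), Image.LANCZOS)
    left = (pil.width - target_width) // 2
    top = (pil.height - target_height) // 2
    return np.array(pil.crop((left, top, left + target_width, top + target_height)))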
@@ -289,7 +299,8 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
             cross_attention_kwargs={'concat_conds': concat_conds},
         ).images.to(vae.dtype) / vae.config.scaling_factor
     else:
-        bg = resize_and_center_crop(input_bg, image_width, image_height)
+        #bg = input_bg
+        bg = resize_and_center_crop(input_bg, image_width, image_height)
         bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
         bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
         latents = i2i_pipe(
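
In the background-conditioned branch, the cropped background is pushed through the VAE before the img2img pipeline runs. A condensed sketch of that path, assuming a diffusers AutoencoderKL as vae; numpy2pytorch here is an assumed reimplementation of the app.py helper (uint8 HWC images in, float BCHW tensors in roughly [-1, 1] out):

import numpy as np
import torch

def numpy2pytorch(imgs):
    # Assumed scaling: / 127.0 - 1.0 maps pixel value 127 to exactly 0.0.
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0
    return h.movedim(-1, 1)

@torch.inference_mode()
def encode_background(vae, bg):
    x = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
    # .mode() takes the deterministic center of the posterior instead of sampling it.
    return vae.encode(x).latent_dist.mode() * vae.config.scaling_factor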
@@ -391,14 +402,12 @@ class BGSource(Enum):
 block = gr.Blocks().queue()
 with block:
     with gr.Row():
-        gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
-    with gr.Row():
-        gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
+        gr.Markdown("## wow dub")
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480)
-                output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
+                input_fg = gr.Image(sources='upload', type="numpy", label="Image", image_mode=['RGBA', 'RGB'])
+                output_bg = gr.Image(type="numpy", label="Preprocessed Foreground")
             prompt = gr.Textbox(label="Prompt")
             bg_source = gr.Radio(choices=[e.value for e in BGSource],
                                  value=BGSource.NONE.value,
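
image_mode=['RGBA', 'RGB'] is there so uploads keep their alpha channel for merge_alpha; whether gr.Image accepts a list for this argument depends on the Gradio version. A conservative sketch that requests RGBA outright (RGB uploads then arrive padded with alpha = 255, which the C == 4 branch of merge_alpha handles):

import gradio as gr

with gr.Blocks() as demo:
    # Force RGBA so transparency survives into the numpy callback.
    input_fg = gr.Image(sources='upload', type="numpy", label="Image", image_mode="RGBA")
    output_bg = gr.Image(type="numpy", label="Preprocessed Foreground")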
@@ -413,8 +422,8 @@ with block:
                 seed = gr.Number(label="Seed", value=12345, precision=0)
 
             with gr.Row():
-                image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
-                image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
+                image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=512, step=64)
+                image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=640, step=64)
 
             with gr.Accordion("Advanced options", open=False):
                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
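
The slider maximums move from 1024 to 2048 while step=64 stays. The step matters: dimensions that are multiples of 64 give an 8x-downsampled VAE latent that is a multiple of 8, which the UNet's down/up blocks divide cleanly. A hypothetical helper for snapping arbitrary sizes onto the same grid:

def snap64(x, lo=256, hi=2048):
    # Round to the nearest multiple of 64, clamped to the slider range.
    return min(max(round(x / 64) * 64, lo), hi)

assert snap64(500) == 512
assert snap64(3000) == 2048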
@@ -428,15 +437,7 @@ with block:
             result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
         with gr.Row():
             dummy_image_for_outputs = gr.Image(visible=False, label='Result')
-            gr.Examples(
-                fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
-                examples=db_examples.foreground_conditioned_examples,
-                inputs=[
-                    input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
-                ],
-                outputs=[result_gallery, output_bg],
-                run_on_click=True, examples_per_page=1024
-            )
+
     ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
     relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
     example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
 