not-lain commited on
Commit
e406805
·
1 Parent(s): 991bda2

refactor inpaint function to return composited image and update main function signature

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -53,13 +53,13 @@ def inpaint(
53
  padding_left=0,
54
  padding_right=0,
55
  prompt="",
56
- progress=gr.Progress(track_tqdm=True),
57
  ):
58
  background, mask = prepare_image_and_mask(
59
  image, padding_top, padding_bottom, padding_left, padding_right
60
  )
61
 
62
- # generator = torch.Generator(device="cuda").manual_seed(42)
 
63
 
64
  result = pipe(
65
  prompt=prompt,
@@ -72,7 +72,9 @@ def inpaint(
72
  ).images[0]
73
 
74
  result = result.convert("RGBA")
75
- return result
 
 
76
 
77
 
78
  def rmbg(image, url):
@@ -92,11 +94,11 @@ def rmbg(image, url):
92
 
93
 
94
  @spaces.GPU
95
- def main(*args, **kwargs):
96
- if len (args) == 2:
97
  return rmbg(*args)
98
- else :
99
- return inpaint(*args, **kwargs)
100
 
101
 
102
  rmbg_tab = gr.Interface(
 
53
  padding_left=0,
54
  padding_right=0,
55
  prompt="",
 
56
  ):
57
  background, mask = prepare_image_and_mask(
58
  image, padding_top, padding_bottom, padding_left, padding_right
59
  )
60
 
61
+ cnet_image = background.copy()
62
+ cnet_image.paste(0, (0, 0), mask)
63
 
64
  result = pipe(
65
  prompt=prompt,
 
72
  ).images[0]
73
 
74
  result = result.convert("RGBA")
75
+ cnet_image.paste(result, (0, 0), mask)
76
+
77
+ return cnet_image
78
 
79
 
80
  def rmbg(image, url):
 
94
 
95
 
96
  @spaces.GPU
97
+ def main(*args, progress=gr.Progress(track_tqdm=True)):
98
+ if len(args) == 2:
99
  return rmbg(*args)
100
+ else:
101
+ return inpaint(*args)
102
 
103
 
104
  rmbg_tab = gr.Interface(