theSure commited on
Commit
069a909
Β·
verified Β·
1 Parent(s): 78aea0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -27
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import io
2
  import os
3
- os.system("pip uninstall -y gradio")
4
- os.system("pip install gradio==3.49.0")
5
  import shutil
6
  import uuid
7
  import torch
@@ -106,69 +104,77 @@ def predict(
106
  scale,
107
  image_paths,
108
  mask_paths
 
109
  ):
110
  global image_path, mask_path
111
  gr.Info(str(f"Set seed = {seed}"))
112
  if image_paths is not None:
113
- input_image["image"] = load_image(image_paths).convert("RGB")
114
- input_image["mask"] = load_image(mask_paths).convert("RGB")
115
 
116
- size1, size2 = input_image["image"].convert("RGB").size
117
-
118
- icc_profile = input_image["image"].info.get('icc_profile')
119
  if icc_profile:
120
  gr.Info(str(f"Image detected to contain ICC profile, converting color space to sRGB..."))
121
  srgb_profile = ImageCms.createProfile("sRGB")
122
  io_handle = io.BytesIO(icc_profile)
123
  src_profile = ImageCms.ImageCmsProfile(io_handle)
124
- input_image["image"] = ImageCms.profileToProfile(input_image["image"], src_profile, srgb_profile)
125
- input_image["image"].info.pop('icc_profile', None)
126
 
127
  if size1 < size2:
128
- input_image["image"] = input_image["image"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
129
  else:
130
- input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
131
 
132
- img = np.array(input_image["image"].convert("RGB"))
133
 
134
  W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
135
  H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
136
 
137
- input_image["image"] = input_image["image"].resize((H, W))
138
- input_image["mask"] = input_image["mask"].resize((H, W))
139
 
140
  if seed == -1:
141
  seed = random.randint(1, 2147483647)
142
  set_seed(random.randint(1, 2147483647))
143
  else:
144
  set_seed(seed)
145
-
146
-
 
 
 
 
 
 
 
 
147
  result = pipe(
148
  prompt=prompt,
149
- control_image=input_image["image"].convert("RGB"),
150
- control_mask=input_image["mask"].convert("RGB"),
151
  width=H,
152
  height=W,
153
  num_inference_steps=ddim_steps,
154
- generator=torch.Generator(device).manual_seed(seed),
155
  guidance_scale=scale,
156
  max_sequence_length=512,
157
  ).images[0]
158
 
159
- mask_np = np.array(input_image["mask"].convert("RGB"))
160
- red = np.array(input_image["image"]).astype("float") * 1
161
  red[:, :, 0] = 180.0
162
  red[:, :, 2] = 0
163
  red[:, :, 1] = 0
164
- result_m = np.array(input_image["image"])
165
  result_m = Image.fromarray(
166
  (
167
  result_m.astype("float") * (1 - mask_np.astype("float") / 512.0) + mask_np.astype("float") / 512.0 * red
168
  ).astype("uint8")
169
  )
170
 
171
- dict_res = [input_image["image"], input_image["mask"], result_m, result]
172
 
173
  dict_out = [result]
174
  image_path = None
@@ -182,6 +188,7 @@ def infer(
182
  seed,
183
  scale,
184
  removal_prompt,
 
185
  ):
186
  img_path = image_path
187
  msk_path = mask_path
@@ -205,7 +212,9 @@ def process_example(image_paths, mask_paths):
205
  mask_path = mask_paths
206
  return masked_image
207
  custom_css = """
 
208
  .contain { max-width: 1200px !important; }
 
209
  .custom-image {
210
  border: 2px dashed #7e22ce !important;
211
  border-radius: 12px !important;
@@ -215,6 +224,7 @@ custom_css = """
215
  border-color: #9333ea !important;
216
  box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important;
217
  }
 
218
  .btn-primary {
219
  background: linear-gradient(45deg, #7e22ce, #9333ea) !important;
220
  border: none !important;
@@ -227,14 +237,17 @@ custom_css = """
227
  padding: 16px !important;
228
  margin-top: 8px !important;
229
  }
 
230
  #inline-examples .thumbnail {
231
  border-radius: 8px !important;
232
  transition: transform 0.2s ease !important;
233
  }
 
234
  #inline-examples .thumbnail:hover {
235
  transform: scale(1.05);
236
  box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
237
  }
 
238
  .example-title h3 {
239
  margin: 0 0 12px 0 !important;
240
  color: #475569 !important;
@@ -242,11 +255,16 @@ custom_css = """
242
  display: flex !important;
243
  align-items: center !important;
244
  }
 
245
  .example-title h3::before {
246
  content: "πŸ“š";
247
  margin-right: 8px;
248
  font-size: 1.2em;
249
  }
 
 
 
 
250
  """
251
 
252
  with gr.Blocks(
@@ -273,16 +291,15 @@ with gr.Blocks(
273
  </div>
274
  """)
275
 
276
- with gr.Row(equal_height=True):
277
  with gr.Column(scale=1, variant="panel"):
278
  gr.Markdown("## πŸ“₯ Input Panel")
279
 
280
  with gr.Group():
281
- input_image = gr.Image(
 
282
  type="pil",
283
- tool="sketch",
284
  label="Upload & Annotate",
285
- height=400,
286
  elem_id="custom-image",
287
  interactive=True
288
  )
 
1
  import io
2
  import os
 
 
3
  import shutil
4
  import uuid
5
  import torch
 
104
  scale,
105
  image_paths,
106
  mask_paths
107
+
108
  ):
109
  global image_path, mask_path
110
  gr.Info(str(f"Set seed = {seed}"))
111
  if image_paths is not None:
112
+ input_image["background"] = load_image(image_paths).convert("RGB")
113
+ input_image["layers"][0] = load_image(mask_paths).convert("RGB")
114
 
115
+ size1, size2 = input_image["background"].convert("RGB").size
116
+ icc_profile = input_image["background"].info.get('icc_profile')
 
117
  if icc_profile:
118
  gr.Info(str(f"Image detected to contain ICC profile, converting color space to sRGB..."))
119
  srgb_profile = ImageCms.createProfile("sRGB")
120
  io_handle = io.BytesIO(icc_profile)
121
  src_profile = ImageCms.ImageCmsProfile(io_handle)
122
+ input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
123
+ input_image["background"].info.pop('icc_profile', None)
124
 
125
  if size1 < size2:
126
+ input_image["background"] = input_image["background"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
127
  else:
128
+ input_image["background"] = input_image["background"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
129
 
130
+ img = np.array(input_image["background"].convert("RGB"))
131
 
132
  W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
133
  H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
134
 
135
+ input_image["background"] = input_image["background"].resize((H, W))
136
+ input_image["layers"][0] = input_image["layers"][0].resize((H, W))
137
 
138
  if seed == -1:
139
  seed = random.randint(1, 2147483647)
140
  set_seed(random.randint(1, 2147483647))
141
  else:
142
  set_seed(seed)
143
+ if image_paths is None:
144
+ img=input_image["layers"][0]
145
+ img_data = np.array(img)
146
+ alpha_channel = img_data[:, :, 3]
147
+ white_background = np.ones_like(alpha_channel) * 255
148
+ gray_image = white_background.copy()
149
+ gray_image[alpha_channel == 0] = 0
150
+ gray_image_pil = Image.fromarray(gray_image).convert('L')
151
+ else:
152
+ gray_image_pil = input_image["layers"][0]
153
  result = pipe(
154
  prompt=prompt,
155
+ control_image=input_image["background"].convert("RGB"),
156
+ control_mask=gray_image_pil.convert("RGB"),
157
  width=H,
158
  height=W,
159
  num_inference_steps=ddim_steps,
160
+ generator=torch.Generator("cuda").manual_seed(seed),
161
  guidance_scale=scale,
162
  max_sequence_length=512,
163
  ).images[0]
164
 
165
+ mask_np = np.array(input_image["layers"][0].convert("RGB"))
166
+ red = np.array(input_image["background"]).astype("float") * 1
167
  red[:, :, 0] = 180.0
168
  red[:, :, 2] = 0
169
  red[:, :, 1] = 0
170
+ result_m = np.array(input_image["background"])
171
  result_m = Image.fromarray(
172
  (
173
  result_m.astype("float") * (1 - mask_np.astype("float") / 512.0) + mask_np.astype("float") / 512.0 * red
174
  ).astype("uint8")
175
  )
176
 
177
+ dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
178
 
179
  dict_out = [result]
180
  image_path = None
 
188
  seed,
189
  scale,
190
  removal_prompt,
191
+
192
  ):
193
  img_path = image_path
194
  msk_path = mask_path
 
212
  mask_path = mask_paths
213
  return masked_image
214
  custom_css = """
215
+
216
  .contain { max-width: 1200px !important; }
217
+
218
  .custom-image {
219
  border: 2px dashed #7e22ce !important;
220
  border-radius: 12px !important;
 
224
  border-color: #9333ea !important;
225
  box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important;
226
  }
227
+
228
  .btn-primary {
229
  background: linear-gradient(45deg, #7e22ce, #9333ea) !important;
230
  border: none !important;
 
237
  padding: 16px !important;
238
  margin-top: 8px !important;
239
  }
240
+
241
  #inline-examples .thumbnail {
242
  border-radius: 8px !important;
243
  transition: transform 0.2s ease !important;
244
  }
245
+
246
  #inline-examples .thumbnail:hover {
247
  transform: scale(1.05);
248
  box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
249
  }
250
+
251
  .example-title h3 {
252
  margin: 0 0 12px 0 !important;
253
  color: #475569 !important;
 
255
  display: flex !important;
256
  align-items: center !important;
257
  }
258
+
259
  .example-title h3::before {
260
  content: "πŸ“š";
261
  margin-right: 8px;
262
  font-size: 1.2em;
263
  }
264
+
265
+ .row { align-items: stretch !important; }
266
+
267
+ .panel { height: 100%; }
268
  """
269
 
270
  with gr.Blocks(
 
291
  </div>
292
  """)
293
 
294
+ with gr.Row(equal_height=False):
295
  with gr.Column(scale=1, variant="panel"):
296
  gr.Markdown("## πŸ“₯ Input Panel")
297
 
298
  with gr.Group():
299
+ input_image = gr.Sketchpad(
300
+ sources=["upload"],
301
  type="pil",
 
302
  label="Upload & Annotate",
 
303
  elem_id="custom-image",
304
  interactive=True
305
  )