ZhengPeng7 commited on
Commit
84abebf
·
verified ·
1 Parent(s): 6df6515

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -157,7 +157,12 @@ def predict(images, resolution, weights_file):
157
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
158
  image_masked.save(save_file_path)
159
  save_paths.append(save_file_path)
160
- image_masked = np.array(image_masked.convert('RGB'))
 
 
 
 
 
161
 
162
  if tab_is_batch:
163
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
 
157
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
158
  image_masked.save(save_file_path)
159
  save_paths.append(save_file_path)
160
+
161
+ # Apply the prediction mask to the original image
162
+ pred = torch.nn.functional.interpolate(pred, size=image_shape, mode='bilinear', align_corners=True).numpy()
163
+ image_pil = image_pil.resize(pred.shape[::-1])
164
+ pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
165
+ image_masked = (pred * np.array(image_pil)).astype(np.uint8)
166
 
167
  if tab_is_batch:
168
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))