Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -157,7 +157,12 @@ def predict(images, resolution, weights_file):
|
|
157 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
158 |
image_masked.save(save_file_path)
|
159 |
save_paths.append(save_file_path)
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
if tab_is_batch:
|
163 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
|
|
157 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
158 |
image_masked.save(save_file_path)
|
159 |
save_paths.append(save_file_path)
|
160 |
+
|
161 |
+
# Apply the prediction mask to the original image
|
162 |
+
pred = torch.nn.functional.interpolate(pred, size=image_shape, mode='bilinear', align_corners=True).numpy()
|
163 |
+
image_pil = image_pil.resize(pred.shape[::-1])
|
164 |
+
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
165 |
+
image_masked = (pred * np.array(image_pil)).astype(np.uint8)
|
166 |
|
167 |
if tab_is_batch:
|
168 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|