jens commited on
Commit
c93bd34
·
1 Parent(s): 2f5a3a2

mirror image and reverse sort

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. inference.py +1 -1
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gradio as gr
3
  import numpy as np
4
  import cv2
5
- from PIL import Image
6
  import torch
7
  from inference import SegmentPredictor, DepthPredictor
8
  from utils import generate_PCL, PCL3, point_cloud
@@ -76,10 +76,11 @@ with block:
76
  sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, cube_size, selected_masks_image}
77
 
78
  def on_upload_image(input_image, upload_image):
79
- return [upload_image, upload_image]
 
 
80
  upload_image.upload(on_upload_image, [input_image, upload_image], [input_image, upload_image])
81
 
82
-
83
  # event - init coords
84
  def on_reset_btn_click(input_image):
85
  return input_image, point_coords_empty(), point_labels_empty(), None, []
 
2
  import gradio as gr
3
  import numpy as np
4
  import cv2
5
+ from PIL import Image, ImageOps
6
  import torch
7
  from inference import SegmentPredictor, DepthPredictor
8
  from utils import generate_PCL, PCL3, point_cloud
 
76
  sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, cube_size, selected_masks_image}
77
 
78
  def on_upload_image(input_image, upload_image):
79
+ ## Mirror because gradio.image webcam has mirror = True
80
+ upload_image_mirror = ImageOps.mirror(upload_image)
81
+ return [upload_image_mirror, upload_image]
82
  upload_image.upload(on_upload_image, [input_image, upload_image], [input_image, upload_image])
83
 
 
84
  # event - init coords
85
  def on_reset_btn_click(input_image):
86
  return input_image, point_coords_empty(), point_labels_empty(), None, []
inference.py CHANGED
@@ -189,7 +189,7 @@ class SegmentPredictor:
189
  point_labels=lbls,
190
  multimask_output=True
191
  )
192
- idxs = np.argsort(masks.sum(axis=(1,2)))
193
  sam_masks = []
194
  for n,i in enumerate(idxs):
195
  sam_masks.append((masks[i], str(n)))
 
189
  point_labels=lbls,
190
  multimask_output=True
191
  )
192
+ idxs = np.argsort(-masks.sum(axis=(1,2)))
193
  sam_masks = []
194
  for n,i in enumerate(idxs):
195
  sam_masks.append((masks[i], str(n)))