rsrao1729 committed on
Commit
cb7ed6d
·
1 Parent(s): 49ab870

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -34
app.py CHANGED
@@ -6,6 +6,13 @@ import streamlit as st
6
  from PIL import Image
7
  from transformers import SamModel, SamProcessor
8
  import cv2
 
 
 
 
 
 
 
9
 
10
 
11
 
@@ -15,21 +22,6 @@ MAX_WIDTH = 700
15
 
16
 
17
  # Define helpful functions
18
- def show_anns(anns):
19
- if len(anns) == 0:
20
- return
21
- sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
22
- ax = plt.gca()
23
- ax.set_autoscale_on(False)
24
-
25
- img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
26
- img[:,:,3] = 0
27
- for ann in sorted_anns:
28
- m = ann['segmentation']
29
- color_mask = np.concatenate([np.random.random(3), [0.35]])
30
- img[m] = color_mask
31
- ax.imshow(img)
32
-
33
  def show_mask(mask, ax, random_color=False):
34
  if random_color:
35
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -43,22 +35,6 @@ def show_points(coords, labels, ax, marker_size=20):
43
  pos_points = coords[labels==1]
44
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.', s=marker_size, edgecolor='white', linewidth=0.2)
45
 
46
- def show_masks_on_image(raw_image, masks, scores):
47
- if len(masks.shape) == 4:
48
- masks = masks.squeeze()
49
- if scores.shape[0] == 1:
50
- scores = scores.squeeze()
51
-
52
- nb_predictions = scores.shape[-1]
53
- fig, ax = plt.subplots(1, nb_predictions)
54
-
55
- for i, (mask, score) in enumerate(zip(masks, scores)):
56
- mask = mask.cpu().detach()
57
- ax[i].imshow(np.array(raw_image))
58
- show_mask(mask, ax[i])
59
- ax[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
60
- ax[i].axis("off")
61
-
62
  def show_points_on_image(raw_image, input_point, ax, input_labels=None):
63
  ax.imshow(raw_image)
64
  input_point = np.array(input_point)
@@ -71,7 +47,6 @@ def show_points_on_image(raw_image, input_point, ax, input_labels=None):
71
 
72
 
73
 
74
-
75
  # Get SAM
76
  if torch.cuda.is_available():
77
  device = 'cuda'
@@ -93,12 +68,22 @@ if scale:
93
  scale_np = np.asarray(bytearray(scale.read()), dtype=np.uint8)
94
  scale_np = cv2.imdecode(scale_np, 1)
95
 
 
 
 
 
 
 
 
 
 
96
  #inputs = processor(raw_image, return_tensors="pt").to(device)
97
  inputs = processor(scale_np, return_tensors="pt").to(device)
98
  image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
99
 
100
  scale_factor = scale_np.shape[1] / MAX_WIDTH # how many times larger scale_np is than the image shown for each dimension
101
- clicked_point = streamlit_image_coordinates(Image.open(scale.name), height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
 
102
  if clicked_point:
103
  input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
104
  input_point_list = [input_point_np.astype(int).tolist()]
@@ -137,12 +122,21 @@ if image:
137
  image_np = np.asarray(bytearray(image.read()), dtype=np.uint8)
138
  image_np = cv2.imdecode(image_np, 1)
139
 
 
 
 
 
 
 
 
 
 
140
  #inputs = processor(raw_image, return_tensors="pt").to(device)
141
  inputs = processor(image_np, return_tensors="pt").to(device)
142
  image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
143
 
144
  scale_factor = image_np.shape[1] / MAX_WIDTH # how many times larger image_np is than the image shown for each dimension
145
- clicked_point = streamlit_image_coordinates(Image.open(image.name), height=image_np.shape[0] // scale_factor, width=MAX_WIDTH)
146
  if clicked_point:
147
  input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
148
  input_point_list = [input_point_np.astype(int).tolist()]
 
6
  from PIL import Image
7
  from transformers import SamModel, SamProcessor
8
  import cv2
9
+ import os
10
+
11
+
12
+
13
# Empty the images folder before starting, so stale uploads from a previous
# session are not served again.
# BUG FIX: os.listdir returns bare filenames, not paths — the original
# os.remove(path) looked for the file in the CWD instead of inside 'images'.
for filename in os.listdir('images'):
    os.remove(os.path.join('images', filename))
16
 
17
 
18
 
 
22
 
23
 
24
  # Define helpful functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def show_mask(mask, ax, random_color=False):
26
  if random_color:
27
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
35
  pos_points = coords[labels==1]
36
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.', s=marker_size, edgecolor='white', linewidth=0.2)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def show_points_on_image(raw_image, input_point, ax, input_labels=None):
39
  ax.imshow(raw_image)
40
  input_point = np.array(input_point)
 
47
 
48
 
49
 
 
50
  # Get SAM
51
  if torch.cuda.is_available():
52
  device = 'cuda'
 
68
  scale_np = np.asarray(bytearray(scale.read()), dtype=np.uint8)
69
  scale_np = cv2.imdecode(scale_np, 1)
70
 
71
+ # Save image if it isn't already saved
72
+ if not os.path.exists(os.path.join("images/", scale.name)):
73
+ with open(os.path.join("images/", scale.name), "wb") as f:
74
+ f.write(scale.getbuffer())
75
+ scale_pil = Image.open(os.path.join("images/", scale.name))
76
+
77
+ # Remove file when done
78
+ os.remove(os.path.join("images/", scale.name))
79
+
80
  #inputs = processor(raw_image, return_tensors="pt").to(device)
81
  inputs = processor(scale_np, return_tensors="pt").to(device)
82
  image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
83
 
84
  scale_factor = scale_np.shape[1] / MAX_WIDTH # how many times larger scale_np is than the image shown for each dimension
85
+ #clicked_point = streamlit_image_coordinates(Image.open(scale.name), height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
86
+ clicked_point = streamlit_image_coordinates(scale_pil, height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
87
  if clicked_point:
88
  input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
89
  input_point_list = [input_point_np.astype(int).tolist()]
 
122
  image_np = np.asarray(bytearray(image.read()), dtype=np.uint8)
123
  image_np = cv2.imdecode(image_np, 1)
124
 
125
+ # Save image if it isn't already saved
126
+ if not os.path.exists(os.path.join("images/", image.name)):
127
+ with open(os.path.join("images/", image.name), "wb") as f:
128
+ f.write(image.getbuffer())
129
+ image_pil = Image.open(os.path.join("images/", image.name))
130
+
131
+ # Remove file when done
132
+ os.remove(os.path.join("images/", image.name))
133
+
134
  #inputs = processor(raw_image, return_tensors="pt").to(device)
135
  inputs = processor(image_np, return_tensors="pt").to(device)
136
  image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
137
 
138
  scale_factor = image_np.shape[1] / MAX_WIDTH # how many times larger image_np is than the image shown for each dimension
139
+ clicked_point = streamlit_image_coordinates(image_pil, height=image_np.shape[0] // scale_factor, width=MAX_WIDTH)
140
  if clicked_point:
141
  input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
142
  input_point_list = [input_point_np.astype(int).tolist()]