Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -6,6 +6,13 @@ import streamlit as st
 from PIL import Image
 from transformers import SamModel, SamProcessor
 import cv2
+import os
+
+
+
+# Empty the images folder before starting
+for path in os.listdir('images'):
+    os.remove(os.path.join('images', path))  # listdir returns bare filenames, so join with the folder



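Note on the new cleanup loop: os.listdir returns bare filenames, so each entry has to be joined with the images directory before os.remove is called (reflected above); the loop also assumes the images/ folder already exists in the Space. A slightly more defensive sketch, with IMAGES_DIR as an illustrative name rather than a constant from the app:

    import os

    IMAGES_DIR = 'images'  # illustrative; app.py hard-codes the 'images' folder

    # Create the folder if it is missing, then clear any leftover files
    os.makedirs(IMAGES_DIR, exist_ok=True)
    for name in os.listdir(IMAGES_DIR):
        full_path = os.path.join(IMAGES_DIR, name)
        if os.path.isfile(full_path):
            os.remove(full_path)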
@@ -15,21 +22,6 @@ MAX_WIDTH = 700


 # Define helpful functions
-def show_anns(anns):
-    if len(anns) == 0:
-        return
-    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
-    ax = plt.gca()
-    ax.set_autoscale_on(False)
-
-    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
-    img[:,:,3] = 0
-    for ann in sorted_anns:
-        m = ann['segmentation']
-        color_mask = np.concatenate([np.random.random(3), [0.35]])
-        img[m] = color_mask
-    ax.imshow(img)
-
 def show_mask(mask, ax, random_color=False):
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -43,22 +35,6 @@ def show_points(coords, labels, ax, marker_size=20):
     pos_points = coords[labels==1]
     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.', s=marker_size, edgecolor='white', linewidth=0.2)

-def show_masks_on_image(raw_image, masks, scores):
-    if len(masks.shape) == 4:
-        masks = masks.squeeze()
-    if scores.shape[0] == 1:
-        scores = scores.squeeze()
-
-    nb_predictions = scores.shape[-1]
-    fig, ax = plt.subplots(1, nb_predictions)
-
-    for i, (mask, score) in enumerate(zip(masks, scores)):
-        mask = mask.cpu().detach()
-        ax[i].imshow(np.array(raw_image))
-        show_mask(mask, ax[i])
-        ax[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
-        ax[i].axis("off")
-
 def show_points_on_image(raw_image, input_point, ax, input_labels=None):
     ax.imshow(raw_image)
     input_point = np.array(input_point)
@@ -71,7 +47,6 @@ def show_points_on_image(raw_image, input_point, ax, input_labels=None):



-
 # Get SAM
 if torch.cuda.is_available():
     device = 'cuda'
@@ -93,12 +68,22 @@ if scale:
     scale_np = np.asarray(bytearray(scale.read()), dtype=np.uint8)
     scale_np = cv2.imdecode(scale_np, 1)

+    # Save image if it isn't already saved
+    if not os.path.exists(os.path.join("images/", scale.name)):
+        with open(os.path.join("images/", scale.name), "wb") as f:
+            f.write(scale.getbuffer())
+    scale_pil = Image.open(os.path.join("images/", scale.name))
+
+    # Remove file when done
+    os.remove(os.path.join("images/", scale.name))
+
     #inputs = processor(raw_image, return_tensors="pt").to(device)
     inputs = processor(scale_np, return_tensors="pt").to(device)
     image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

     scale_factor = scale_np.shape[1] / MAX_WIDTH  # how many times larger scale_np is than the image shown for each dimension
-    clicked_point = streamlit_image_coordinates(Image.open(scale.name), height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
+    #clicked_point = streamlit_image_coordinates(Image.open(scale.name), height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
+    clicked_point = streamlit_image_coordinates(scale_pil, height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
     if clicked_point:
         input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
         input_point_list = [input_point_np.astype(int).tolist()]
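This hunk addresses the runtime error: Streamlit's UploadedFile.name is only the original filename, not a path on disk, so the old Image.open(scale.name) call had no file to open. The committed fix writes the upload into images/, opens it with PIL, then deletes it. Because UploadedFile is an io.BytesIO subclass, the disk round trip can also be skipped entirely; a minimal sketch, assuming the same scale uploader object:

    import io
    from PIL import Image

    # getvalue() returns the full upload bytes even after scale.read()
    # has already advanced the stream position.
    scale_pil = Image.open(io.BytesIO(scale.getvalue()))
    scale_pil.load()  # force decoding now, so no open file is needed later

One caveat about the committed version: PIL decodes lazily, so deleting the file immediately after Image.open relies on the still-open file handle; calling .load() before os.remove (or keeping the image in memory as above) avoids that dependency.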
@@ -137,12 +122,21 @@ if image:
     image_np = np.asarray(bytearray(image.read()), dtype=np.uint8)
     image_np = cv2.imdecode(image_np, 1)

+    # Save image if it isn't already saved
+    if not os.path.exists(os.path.join("images/", image.name)):
+        with open(os.path.join("images/", image.name), "wb") as f:
+            f.write(image.getbuffer())
+    image_pil = Image.open(os.path.join("images/", image.name))
+
+    # Remove file when done
+    os.remove(os.path.join("images/", image.name))
+
     #inputs = processor(raw_image, return_tensors="pt").to(device)
     inputs = processor(image_np, return_tensors="pt").to(device)
     image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

     scale_factor = image_np.shape[1] / MAX_WIDTH  # how many times larger image_np is than the image shown for each dimension
-    clicked_point = streamlit_image_coordinates(Image.open(image.name), height=image_np.shape[0] // scale_factor, width=MAX_WIDTH)
+    clicked_point = streamlit_image_coordinates(image_pil, height=image_np.shape[0] // scale_factor, width=MAX_WIDTH)
     if clicked_point:
         input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
         input_point_list = [input_point_np.astype(int).tolist()]
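The measurement-image branch gets the same treatment. For reference, the click-to-pixel mapping these lines rely on: streamlit_image_coordinates reports clicks in the coordinates of the downscaled displayed image, and multiplying by scale_factor maps them back to full resolution. A small worked example with illustrative numbers:

    import numpy as np

    MAX_WIDTH = 700
    full_width = 2100                      # illustrative full-resolution width
    scale_factor = full_width / MAX_WIDTH  # 3.0: the display is 3x smaller

    clicked_point = {'x': 100, 'y': 50}    # click in displayed-image pixels
    input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
    print(input_point_np.astype(int).tolist())  # [[300, 150]] in full-res pixels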