Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
3 |
+
# from IPython.display import display, Image
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from rembg import remove
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
# Content of detectObjects.py file
|
12 |
+
# import detectObjects
|
13 |
+
import ultralytics
|
14 |
+
from ultralytics import YOLO
|
15 |
+
|
16 |
+
model = YOLO('yolov8n.pt')
|
17 |
+
sam_checkpoint = "sam_vit_b_01ec64.pth"
|
18 |
+
model_type = "vit_b"
|
19 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
20 |
+
predictor = SamPredictor(sam)
|
21 |
+
|
22 |
+
def detected_objects(filename:str):
|
23 |
+
results = model.predict(source=filename, conf=0.25)
|
24 |
+
|
25 |
+
categories = results[0].names
|
26 |
+
|
27 |
+
dc = []
|
28 |
+
for i in range(len(results[0])):
|
29 |
+
cat = results[0].boxes[i].cls
|
30 |
+
dc.append(categories[int(cat)])
|
31 |
+
|
32 |
+
print(dc)
|
33 |
+
return results, dc
|
34 |
+
|
35 |
+
def show_mask(mask, ax, random_color=False):
|
36 |
+
if random_color:
|
37 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
38 |
+
else:
|
39 |
+
color = np.array([30/255, 144/255, 255/255, 0.6])
|
40 |
+
h, w = mask.shape[-2:]
|
41 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
42 |
+
ax.imshow(mask_image)
|
43 |
+
|
44 |
+
def show_points(coords, labels, ax, marker_size=375):
|
45 |
+
pos_points = coords[labels==1]
|
46 |
+
neg_points = coords[labels==0]
|
47 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
48 |
+
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
49 |
+
|
50 |
+
def show_box(box, ax):
|
51 |
+
x0, y0 = box[0], box[1]
|
52 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
53 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
|
54 |
+
|
55 |
+
st.title('Extract Objects From Image')
|
56 |
+
|
57 |
+
uploaded_file = st.file_uploader('Upload an image')
|
58 |
+
|
59 |
+
if uploaded_file is not None:
|
60 |
+
# To read file as bytes:
|
61 |
+
bytes_data = uploaded_file.getvalue()
|
62 |
+
with open('uploaded_file.png','wb') as file:
|
63 |
+
file.write(uploaded_file.getvalue())
|
64 |
+
|
65 |
+
# Detect objects in the uploaded image
|
66 |
+
# results, dc = detectObjects.detected_objects('uploaded_file.png')
|
67 |
+
results, dc = detected_objects('uploaded_file.png')
|
68 |
+
|
69 |
+
st.write(dc)
|
70 |
+
|
71 |
+
option = st.selectbox("Which object would you like to extract?", tuple(dc))
|
72 |
+
# print(option)
|
73 |
+
index_of_the_choosen_detected_object = tuple(dc).index(option)
|
74 |
+
|
75 |
+
if st.button('Extract'):
|
76 |
+
for result in results:
|
77 |
+
boxes = result.boxes
|
78 |
+
|
79 |
+
bbox=boxes.xyxy.tolist()[index_of_the_choosen_detected_object]
|
80 |
+
# sam_checkpoint = "sam_vit_b_01ec64.pth"
|
81 |
+
# model_type = "vit_b"
|
82 |
+
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
83 |
+
# predictor = SamPredictor(sam)
|
84 |
+
|
85 |
+
image = cv2.cvtColor(cv2.imread('uploaded_file.png'), cv2.COLOR_BGR2RGB)
|
86 |
+
predictor.set_image(image)
|
87 |
+
|
88 |
+
input_box = np.array(bbox)
|
89 |
+
|
90 |
+
masks, _, _ = predictor.predict(
|
91 |
+
point_coords=None,
|
92 |
+
point_labels=None,
|
93 |
+
box=input_box[None, :],
|
94 |
+
multimask_output=False,
|
95 |
+
)
|
96 |
+
|
97 |
+
# plt.figure(figsize=(10, 10))
|
98 |
+
# st.image(image)
|
99 |
+
# plt.imshow(image)
|
100 |
+
# show_mask(masks[0], plt.gca())
|
101 |
+
# show_box(input_box, plt.gca())
|
102 |
+
# plt.axis('off')
|
103 |
+
# plt.show()
|
104 |
+
|
105 |
+
segmentation_mask = masks[0]
|
106 |
+
binary_mask = np.where(segmentation_mask > 0.5, 1, 0)
|
107 |
+
|
108 |
+
white_background = np.ones_like(image) * 255
|
109 |
+
|
110 |
+
new_image = white_background * (1 - binary_mask[..., np.newaxis]) + image * binary_mask[..., np.newaxis]
|
111 |
+
|
112 |
+
|
113 |
+
plt.imsave('extracted_image.jpg', new_image.astype(np.uint8))
|
114 |
+
# st.image('extracted_image.jpg')
|
115 |
+
|
116 |
+
# Store path of the image in the variable input_path
|
117 |
+
input_path = 'extracted_image.jpg'
|
118 |
+
|
119 |
+
# Store path of the output image in the variable output_path
|
120 |
+
output_path = 'finalExtracted.png'
|
121 |
+
|
122 |
+
# Processing the image
|
123 |
+
input = Image.open(input_path)
|
124 |
+
|
125 |
+
# Removing the background from the given Image
|
126 |
+
output = remove(input)
|
127 |
+
|
128 |
+
#Saving the image in the given path
|
129 |
+
output.save(output_path)
|
130 |
+
# st.image(output_path)
|
131 |
+
|
132 |
+
with open("finalExtracted.png", "rb") as file:
|
133 |
+
btn = st.download_button(
|
134 |
+
label="Download final image",
|
135 |
+
data=file,
|
136 |
+
file_name="finalExtracted.png",
|
137 |
+
mime="image/png",
|
138 |
+
)
|
139 |
+
|
140 |
+
# bbox=boxes.xyxy.tolist()[0]
|