krishanwalia30 commited on
Commit
7613b6c
·
verified ·
1 Parent(s): e16dacf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
3
+ # from IPython.display import display, Image
4
+ import cv2
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from rembg import remove
8
+ from PIL import Image
9
+
10
+
11
+ # Content of detectObjects.py file
12
+ # import detectObjects
13
+ import ultralytics
14
+ from ultralytics import YOLO
15
+
16
+ model = YOLO('yolov8n.pt')
17
+ sam_checkpoint = "sam_vit_b_01ec64.pth"
18
+ model_type = "vit_b"
19
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
20
+ predictor = SamPredictor(sam)
21
+
22
+ def detected_objects(filename:str):
23
+ results = model.predict(source=filename, conf=0.25)
24
+
25
+ categories = results[0].names
26
+
27
+ dc = []
28
+ for i in range(len(results[0])):
29
+ cat = results[0].boxes[i].cls
30
+ dc.append(categories[int(cat)])
31
+
32
+ print(dc)
33
+ return results, dc
34
+
35
+ def show_mask(mask, ax, random_color=False):
36
+ if random_color:
37
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
38
+ else:
39
+ color = np.array([30/255, 144/255, 255/255, 0.6])
40
+ h, w = mask.shape[-2:]
41
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
42
+ ax.imshow(mask_image)
43
+
44
+ def show_points(coords, labels, ax, marker_size=375):
45
+ pos_points = coords[labels==1]
46
+ neg_points = coords[labels==0]
47
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
48
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
49
+
50
+ def show_box(box, ax):
51
+ x0, y0 = box[0], box[1]
52
+ w, h = box[2] - box[0], box[3] - box[1]
53
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
54
+
55
+ st.title('Extract Objects From Image')
56
+
57
+ uploaded_file = st.file_uploader('Upload an image')
58
+
59
+ if uploaded_file is not None:
60
+ # To read file as bytes:
61
+ bytes_data = uploaded_file.getvalue()
62
+ with open('uploaded_file.png','wb') as file:
63
+ file.write(uploaded_file.getvalue())
64
+
65
+ # Detect objects in the uploaded image
66
+ # results, dc = detectObjects.detected_objects('uploaded_file.png')
67
+ results, dc = detected_objects('uploaded_file.png')
68
+
69
+ st.write(dc)
70
+
71
+ option = st.selectbox("Which object would you like to extract?", tuple(dc))
72
+ # print(option)
73
+ index_of_the_choosen_detected_object = tuple(dc).index(option)
74
+
75
+ if st.button('Extract'):
76
+ for result in results:
77
+ boxes = result.boxes
78
+
79
+ bbox=boxes.xyxy.tolist()[index_of_the_choosen_detected_object]
80
+ # sam_checkpoint = "sam_vit_b_01ec64.pth"
81
+ # model_type = "vit_b"
82
+ # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
83
+ # predictor = SamPredictor(sam)
84
+
85
+ image = cv2.cvtColor(cv2.imread('uploaded_file.png'), cv2.COLOR_BGR2RGB)
86
+ predictor.set_image(image)
87
+
88
+ input_box = np.array(bbox)
89
+
90
+ masks, _, _ = predictor.predict(
91
+ point_coords=None,
92
+ point_labels=None,
93
+ box=input_box[None, :],
94
+ multimask_output=False,
95
+ )
96
+
97
+ # plt.figure(figsize=(10, 10))
98
+ # st.image(image)
99
+ # plt.imshow(image)
100
+ # show_mask(masks[0], plt.gca())
101
+ # show_box(input_box, plt.gca())
102
+ # plt.axis('off')
103
+ # plt.show()
104
+
105
+ segmentation_mask = masks[0]
106
+ binary_mask = np.where(segmentation_mask > 0.5, 1, 0)
107
+
108
+ white_background = np.ones_like(image) * 255
109
+
110
+ new_image = white_background * (1 - binary_mask[..., np.newaxis]) + image * binary_mask[..., np.newaxis]
111
+
112
+
113
+ plt.imsave('extracted_image.jpg', new_image.astype(np.uint8))
114
+ # st.image('extracted_image.jpg')
115
+
116
+ # Store path of the image in the variable input_path
117
+ input_path = 'extracted_image.jpg'
118
+
119
+ # Store path of the output image in the variable output_path
120
+ output_path = 'finalExtracted.png'
121
+
122
+ # Processing the image
123
+ input = Image.open(input_path)
124
+
125
+ # Removing the background from the given Image
126
+ output = remove(input)
127
+
128
+ #Saving the image in the given path
129
+ output.save(output_path)
130
+ # st.image(output_path)
131
+
132
+ with open("finalExtracted.png", "rb") as file:
133
+ btn = st.download_button(
134
+ label="Download final image",
135
+ data=file,
136
+ file_name="finalExtracted.png",
137
+ mime="image/png",
138
+ )
139
+
140
+ # bbox=boxes.xyxy.tolist()[0]