beingcognitive committed
Commit 6435d5a · verified · 1 Parent(s): bad2a63

Create app.py

Files changed (1)
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
+ import streamlit as st
+ from transformers import AutoProcessor, AutoModelForMaskGeneration
+ from transformers import pipeline
+ from PIL import Image, ImageOps
+ import numpy as np
+ import torch
+ import requests
+ from io import BytesIO
+
+ def main():
+     st.title("Image Segmentation")
+
+     # Load SAM (Segment Anything) by Facebook/Meta
+     processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
+     model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-huge")
+     # Load an object-detection pipeline (DETR) to propose bounding boxes
+     od_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
+
+     uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+     # Fractions of the image size used to place point prompts for SAM
+     xs_ys = [(2.0, 2.0), (2.5, 2.5)]  # , (2.5, 2.0), (2.0, 2.5), (1.5, 1.5)]
+     alpha = 20
+     width = 600
+
+     if uploaded_file is not None:
+         # Convert to RGB so the image composites cleanly with the RGB background below
+         raw_image = Image.open(uploaded_file).convert("RGB")
+
+         st.subheader("Uploaded Image")
+         st.image(raw_image, caption="Uploaded Image", width=width)
+
+         ### STEP 1. Object Detection
+         pipeline_output = od_pipe(raw_image)
+
+         # Convert the bounding boxes from the pipeline output into the format expected by the SAM processor
+         input_boxes_format = [[[b['box']['xmin'], b['box']['ymin']], [b['box']['xmax'], b['box']['ymax']]] for b in pipeline_output]
+         labels_format = [b['label'] for b in pipeline_output]
+         print(input_boxes_format)
+         print(labels_format)
+
+         ### STEP 2. Use each detected box as a SAM prompt
+         for b, l in zip(input_boxes_format, labels_format):
+             with st.spinner('Processing...'):
+                 st.subheader(f'Bounding box: {l}')
+                 inputs = processor(images=raw_image,
+                                    input_boxes=[b],
+                                    return_tensors="pt")
+
+                 with torch.no_grad():
+                     outputs = model(**inputs)
+
+                 predicted_masks = processor.image_processor.post_process_masks(
+                     outputs.pred_masks,
+                     inputs["original_sizes"],
+                     inputs["reshaped_input_sizes"]
+                 )
+                 predicted_mask = predicted_masks[0]
+
+                 # SAM returns three candidate masks per prompt; show each one
+                 for i in range(0, 3):
+                     # 2D boolean mask -> 8-bit grayscale image
+                     mask = predicted_mask[0][i]
+                     int_mask = np.array(mask).astype(int) * 255
+                     mask_image = Image.fromarray(int_mask.astype('uint8'), mode='L')
+
+                     # Colorize the mask to RGB (currently unused); the composite below keeps
+                     # the original pixels inside the mask and white elsewhere
+                     mask_image_rgb = ImageOps.colorize(mask_image, (0, 0, 0), (255, 255, 255))
+                     final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255, 255, 255)), mask_image)
+
+                     # Display the final image
+                     st.image(final_image, caption=f"Masked Image {i+1}", width=width)
+
+         ### STEP 3. Segment again, this time prompting SAM with single points
+         for (x, y) in xs_ys:
+             with st.spinner('Processing...'):
+                 # Calculate input points (image size divided by the chosen factors)
+                 point_x = raw_image.size[0] // x
+                 point_y = raw_image.size[1] // y
+                 input_points = [[[point_x, point_y]]]
+
+                 # Prepare inputs
+                 inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+                 # Generate masks
+                 with torch.no_grad():
+                     outputs = model(**inputs)
+
+                 # Post-process masks
+                 predicted_masks = processor.image_processor.post_process_masks(
+                     outputs.pred_masks,
+                     inputs["original_sizes"],
+                     inputs["reshaped_input_sizes"]
+                 )
+
+                 predicted_mask = predicted_masks[0]
+
+                 # Display masked images
+                 st.subheader(f"Input points : ({1/x}, {1/y})")
+                 for i in range(3):
+                     mask = predicted_mask[0][i]
+                     int_mask = np.array(mask).astype(int) * 255
+                     mask_image = Image.fromarray(int_mask.astype('uint8'), mode='L')
+
+                     # Same masking as above: original pixels inside the mask, white outside
+                     mask_image_rgb = ImageOps.colorize(mask_image, (0, 0, 0), (255, 255, 255))
+                     final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255, 255, 255)), mask_image)
+
+                     st.image(final_image, caption=f"Masked Image {i+1}", width=width)
+
+ if __name__ == "__main__":
+     main()
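
For anyone who wants to sanity-check the SAM calls outside of Streamlit, below is a minimal sketch of the same encode, forward, post_process_masks flow used in app.py, driven by a single point prompt. It is not part of this commit; the image path example.jpg and the choice to keep the mask with the highest predicted IoU are illustrative assumptions.

# Standalone sketch (not part of the commit): same SAM flow as app.py, without Streamlit.
# "example.jpg" is a placeholder path; swap in any local image.
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForMaskGeneration

processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-huge")

raw_image = Image.open("example.jpg").convert("RGB")

# One point prompt at the image centre, mirroring the (2.0, 2.0) entry in xs_ys.
input_points = [[[raw_image.size[0] // 2, raw_image.size[1] // 2]]]
inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Resize the low-resolution masks back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks,
    inputs["original_sizes"],
    inputs["reshaped_input_sizes"],
)[0]

# SAM proposes three masks per prompt; keep the one with the highest predicted IoU.
best = outputs.iou_scores[0, 0].argmax().item()
mask = masks[0][best]
print("mask shape:", tuple(mask.shape), "| selected mask index:", best)

Running the app itself follows the usual Streamlit pattern (streamlit run app.py) once streamlit, transformers, torch, Pillow, and numpy are installed. Note that main() reloads both the sam-vit-huge checkpoint and the DETR pipeline on every Streamlit rerun; caching them, for example with st.cache_resource, would be a natural follow-up.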