NandiniLokeshReddy committed
Commit 7637f71 · 1 Parent(s): db81b45

Add Gradio app and requirements

Files changed (2)
  1. app.py +158 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,158 @@
+import os
+import subprocess
+import sys
+import zipfile
+
+import cv2
+import gradio as gr
+import numpy as np
+import torch
+from torchvision.transforms import ToTensor
+
+# Fetch the EfficientSAM repo on first run; it ships the ViT-S weights
+# as weights/efficient_sam_vits.pt.zip.
+if not os.path.isdir("EfficientSAM"):
+    subprocess.run(
+        ["git", "clone", "https://github.com/yformer/EfficientSAM.git"],
+        check=True,
+    )
+os.chdir("EfficientSAM")
+
+# Download the SAM ViT-H checkpoint into the repo's weights/ directory
+CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
+if not os.path.isfile(CHECKPOINT_PATH):
+    subprocess.run(
+        [
+            "wget", "-q",
+            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+            "-O", CHECKPOINT_PATH,
+        ],
+        check=True,
+    )
+
+# Install segment-anything at runtime if it is not already available
+try:
+    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+except ImportError:
+    subprocess.run(
+        [sys.executable, "-m", "pip", "install",
+         "git+https://github.com/facebookresearch/segment-anything.git"],
+        check=True,
+    )
+    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+
+from efficient_sam.build_efficient_sam import build_efficient_sam_vits
+from segment_anything.utils.amg import (
+    batched_mask_to_box,
+    calculate_stability_score,
+    mask_to_rle_pytorch,
+    remove_small_regions,
+    rle_to_mask,
+)
+from torchvision.ops.boxes import batched_nms
+
+# Constants
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+MODEL_TYPE = "vit_h"
+
+# Load the SAM model and its automatic mask generator
+sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
+mask_generator_sam = SamAutomaticMaskGenerator(sam)
+
+# Unpack and load the EfficientSAM ViT-S model
+with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", "r") as zip_ref:
+    zip_ref.extractall("weights")
+efficient_sam_vits_model = build_efficient_sam_vits()
+
+
+def process_small_region(rles):
+    # Remove small holes and islands, then deduplicate the masks with NMS.
+    new_masks = []
+    scores = []
+    min_area = 100
+    nms_thresh = 0.7
+    for rle in rles:
+        mask = rle_to_mask(rle[0])
+        mask, changed = remove_small_regions(mask, min_area, mode="holes")
+        unchanged = not changed
+        mask, changed = remove_small_regions(mask, min_area, mode="islands")
+        unchanged = unchanged and not changed
+        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+        # Give modified masks a score of 0 so NMS prefers the untouched ones.
+        scores.append(float(unchanged))
+
+    masks = torch.cat(new_masks, dim=0)
+    boxes = batched_mask_to_box(masks)
+    keep_by_nms = batched_nms(
+        boxes.float(),
+        torch.as_tensor(scores),
+        torch.zeros_like(boxes[:, 0]),  # a single dummy category
+        iou_threshold=nms_thresh,
+    )
+    # Re-encode any mask that was modified above.
+    for i_mask in keep_by_nms:
+        if scores[i_mask] == 0.0:
+            mask_torch = masks[i_mask].unsqueeze(0)
+            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
+    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
+    return masks
+
+
+def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
+    predicted_masks, predicted_iou = model(img[None, ...], points, point_labels)
+    # Sort each query's candidate masks by predicted IoU.
+    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
+    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
+    predicted_masks = torch.take_along_dim(
+        predicted_masks, sorted_ids[..., None, None], dim=2
+    )
+    predicted_masks = predicted_masks[0]
+    iou = predicted_iou_scores[0, :, 0]
+    # Keep masks with predicted IoU > 0.7 and stability score > 0.9.
+    index_iou = iou > 0.7
+    iou_ = iou[index_iou]
+    masks = predicted_masks[index_iou]
+    score = calculate_stability_score(masks, 0.0, 1.0)
+    score = score[:, 0]
+    index = score > 0.9
+    masks = masks[index]
+    iou_ = iou_[index]
+    # Binarize the mask logits at 0.
+    masks = torch.ge(masks, 0.0)
+    return masks, iou_
+
+
+def run_everything_ours(image_np, model):
+    # "Segment everything" with EfficientSAM: query the model with a
+    # regular 32x32 grid of points. Runs on CPU.
+    model = model.cpu()
+    # image_np comes from a PIL image, so it is already RGB.
+    img_tensor = ToTensor()(image_np)
+    _, original_image_h, original_image_w = img_tensor.shape
+    xy = []
+    GRID_SIZE = 32
+    for i in range(GRID_SIZE):
+        curr_x = 0.5 + i / GRID_SIZE * original_image_w
+        for j in range(GRID_SIZE):
+            curr_y = 0.5 + j / GRID_SIZE * original_image_h
+            xy.append([curr_x, curr_y])
+    points = torch.from_numpy(np.array(xy))
+    num_pts = points.shape[0]
+    point_labels = torch.ones(num_pts, 1)
+    with torch.no_grad():
+        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
+            img_tensor.cpu(),
+            points.reshape(1, num_pts, 1, 2).cpu(),
+            point_labels.reshape(1, num_pts, 1).cpu(),
+            model,
+        )
+    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
+    return process_small_region(rle)
+
+
+def show_anns_ours(masks, image):
+    # Draw a green contour around every mask.
+    for mask in masks:
+        contours, _ = cv2.findContours(
+            mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+        )
+        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
+    return image
+
+
+def process_image(image):
+    # Gradio hands us an RGB PIL image; both models expect RGB arrays,
+    # so no BGR conversion is needed.
+    image_np = np.array(image)
+
+    # Process with SAM
+    sam_result = mask_generator_sam.generate(image_np)
+
+    # Annotate SAM result: paint every mask solid green
+    sam_annotated_image = image_np.copy()
+    for mask in sam_result:
+        sam_annotated_image[mask["segmentation"]] = [0, 255, 0]
+
+    # Process with EfficientSAM
+    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
+    efficient_sam_annotated_image = show_anns_ours(
+        mask_efficient_sam_vits, image_np.copy()
+    )
+
+    return [image, sam_annotated_image, efficient_sam_annotated_image]
+
+
+# Gradio interface
+interface = gr.Interface(
+    fn=process_image,
+    inputs=gr.Image(type="pil"),
+    outputs=[
+        gr.Image(type="pil", label="Original"),
+        gr.Image(type="pil", label="SAM Segmented"),
+        gr.Image(type="pil", label="EfficientSAM Segmented"),
+    ],
+    title="SAM vs EfficientSAM Comparison",
+    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)
+
+interface.launch(debug=True)
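A quick way to exercise process_image outside the web UI is to call it directly. This is a minimal sketch, not part of the commit: test.jpg is a hypothetical sample image, and it assumes the interface.launch(debug=True) line is temporarily commented out so importing app does not block on the server.

    from PIL import Image
    import app  # runs the model downloads and setup above

    # test.jpg is a hypothetical sample image placed next to app.py
    original, sam_out, efficientsam_out = app.process_image(Image.open("test.jpg"))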
requirements.txt ADDED
@@ -0,0 +1,7 @@
+
+gradio
+torch
+torchvision
+opencv-python-headless
+numpy
+Pillow
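If the runtime pip install in app.py is undesirable, the segment-anything dependency could instead be declared here with a direct reference; a possible extra line (an assumption, not part of this commit) would be:

    segment-anything @ git+https://github.com/facebookresearch/segment-anything.git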