Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,31 +1,62 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
import numpy as np
|
3 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
4 |
import base64
|
5 |
-
from
|
|
|
|
|
|
|
6 |
|
7 |
# Load the segmentation model
|
8 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
9 |
model_type = "vit_h"
|
10 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
image = input_image.astype(np.uint8)
|
16 |
|
17 |
-
# Initialize the mask generator
|
18 |
-
mask_generator = SamAutomaticMaskGenerator(sam)
|
19 |
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
masks = mask_generator.generate(image)
|
22 |
|
23 |
-
|
24 |
-
highest_cosine_base64 = ""
|
25 |
|
26 |
for i, mask_dict in enumerate(masks):
|
27 |
mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
|
28 |
-
segmented_region = cv2.bitwise_and(
|
29 |
|
30 |
x, y, w, h = map(int, mask_dict['bbox'])
|
31 |
cropped_region = segmented_region[y:y+h, x:x+w]
|
@@ -34,25 +65,22 @@ def segment_image(input_image):
|
|
34 |
_, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
|
35 |
segmented_image_base64 = base64.b64encode(buffer).decode()
|
36 |
|
37 |
-
#
|
38 |
-
|
39 |
-
result = client.predict(
|
40 |
-
segmented_image_base64, # Base64 Image
|
41 |
-
"Text input", # Text input
|
42 |
-
api_name="/predict"
|
43 |
-
)
|
44 |
|
45 |
-
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
highest_cosine_base64 = segmented_image_base64
|
50 |
|
51 |
-
|
|
|
52 |
|
53 |
# Create Gradio components
|
54 |
input_image = gr.inputs.Image()
|
55 |
-
|
|
|
56 |
|
57 |
# Create a Gradio interface
|
58 |
-
gr.Interface(fn=segment_image, inputs=input_image, outputs=
|
|
|
1 |
import gradio as gr
|
2 |
+
import cv2
|
3 |
import numpy as np
|
4 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
5 |
import base64
|
6 |
+
from PIL import Image
|
7 |
+
from io import BytesIO
|
8 |
+
import torch
|
9 |
+
import clip
|
10 |
|
11 |
# Load the segmentation model
|
12 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
13 |
model_type = "vit_h"
|
14 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
15 |
|
16 |
+
# Load the CLIP model
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
model, preprocess = clip.load("ViT-B/32", device=device)
|
|
|
19 |
|
|
|
|
|
20 |
|
21 |
+
def find_similarity(base64_image, text_input):
|
22 |
+
# Decode the base64 image to bytes
|
23 |
+
image_bytes = base64.b64decode(base64_image)
|
24 |
+
|
25 |
+
# Convert the bytes to a PIL image
|
26 |
+
image = Image.open(BytesIO(image_bytes))
|
27 |
+
|
28 |
+
# Preprocess the image
|
29 |
+
image = preprocess(image).unsqueeze(0).to(device)
|
30 |
+
|
31 |
+
# Prepare input text
|
32 |
+
text_tokens = clip.tokenize([text_input]).to(device)
|
33 |
+
|
34 |
+
# Encode image and text features
|
35 |
+
|
36 |
+
|
37 |
+
with torch.no_grad():
|
38 |
+
image_features = model.encode_image(image)
|
39 |
+
text_features = model.encode_text(text_tokens)
|
40 |
+
|
41 |
+
# Normalize features and calculate similarity
|
42 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
43 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
44 |
+
similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
|
45 |
+
|
46 |
+
return similarity
|
47 |
+
|
48 |
+
|
49 |
+
# Define a function for image segmentation
|
50 |
+
def segment_image(input_image, text_input):
|
51 |
+
image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
|
52 |
+
mask_generator = SamAutomaticMaskGenerator(sam)
|
53 |
masks = mask_generator.generate(image)
|
54 |
|
55 |
+
segmented_regions = [] # List to store segmented regions with similarity scores
|
|
|
56 |
|
57 |
for i, mask_dict in enumerate(masks):
|
58 |
mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
|
59 |
+
segmented_region = cv2.bitwise_and(input_image, input_image, mask=mask_data)
|
60 |
|
61 |
x, y, w, h = map(int, mask_dict['bbox'])
|
62 |
cropped_region = segmented_region[y:y+h, x:x+w]
|
|
|
65 |
_, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
|
66 |
segmented_image_base64 = base64.b64encode(buffer).decode()
|
67 |
|
68 |
+
# Calculate similarity for the segmented image
|
69 |
+
similarity = find_similarity(segmented_image_base64, text_input)
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
# Append the segmented image and its similarity score
|
72 |
+
segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
|
73 |
|
74 |
+
# Sort the segmented images by similarity in descending order
|
75 |
+
segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
|
|
|
76 |
|
77 |
+
# Return the segmented images in descending order of similarity
|
78 |
+
return segmented_regions
|
79 |
|
80 |
# Create Gradio components
|
81 |
input_image = gr.inputs.Image()
|
82 |
+
text_input = gr.inputs.Text()
|
83 |
+
output_images = gr.outputs.JSON()
|
84 |
|
85 |
# Create a Gradio interface
|
86 |
+
gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
|