Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
import numpy as np
|
3 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
4 |
import base64
|
@@ -6,7 +7,6 @@ from PIL import Image
|
|
6 |
from io import BytesIO
|
7 |
import torch
|
8 |
import clip
|
9 |
-
import json
|
10 |
|
11 |
# Load the segmentation model
|
12 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
@@ -18,10 +18,16 @@ model, preprocess = clip.load("ViT-B/32")
|
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
model.to(device).eval()
|
20 |
|
21 |
-
|
|
|
22 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# Preprocess the image
|
24 |
-
image = Image.fromarray(image)
|
25 |
image = preprocess(image).unsqueeze(0).to(device)
|
26 |
|
27 |
# Prepare input text
|
@@ -37,45 +43,43 @@ def find_similarity(image, text_input):
|
|
37 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
38 |
similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
|
39 |
|
40 |
-
return similarity
|
41 |
-
|
42 |
except Exception as e:
|
43 |
-
return
|
44 |
|
|
|
45 |
def segment_image(input_image, text_input):
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
|
53 |
-
|
54 |
-
masks = mask_generator.generate(np.array(image))
|
55 |
|
56 |
-
|
|
|
|
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
segmented_region = Image.fromarray(mask_data) # Convert mask to an image
|
61 |
|
62 |
-
|
63 |
-
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
|
68 |
-
|
69 |
-
|
70 |
|
71 |
-
|
72 |
-
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
except Exception as e:
|
78 |
-
return json.dumps({"error": str(e)})
|
79 |
|
80 |
# Create Gradio components
|
81 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|
|
|
1 |
import gradio as gr
|
2 |
+
import cv2
|
3 |
import numpy as np
|
4 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
5 |
import base64
|
|
|
7 |
from io import BytesIO
|
8 |
import torch
|
9 |
import clip
|
|
|
10 |
|
11 |
# Load the segmentation model
|
12 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
|
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
model.to(device).eval()
|
20 |
|
21 |
+
|
22 |
+
def find_similarity(base64_image, text_input):
|
23 |
try:
|
24 |
+
# Decode the base64 image to bytes
|
25 |
+
image_bytes = base64.b64decode(base64_image)
|
26 |
+
|
27 |
+
# Convert the bytes to a PIL image
|
28 |
+
image = Image.open(BytesIO(image_bytes))
|
29 |
+
|
30 |
# Preprocess the image
|
|
|
31 |
image = preprocess(image).unsqueeze(0).to(device)
|
32 |
|
33 |
# Prepare input text
|
|
|
43 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
44 |
similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
|
45 |
|
46 |
+
return similarity
|
|
|
47 |
except Exception as e:
|
48 |
+
return str(e)
|
49 |
|
50 |
+
# Define a function for image segmentation
|
51 |
def segment_image(input_image, text_input):
    """Segment a base64-encoded image with SAM and rank the regions by text similarity.

    Args:
        input_image: Base64-encoded image data in any format PIL can open.
        text_input: Text prompt forwarded to ``find_similarity`` for each region.

    Returns:
        A list of ``{"image": <base64 PNG>, "similarity": <score>}`` dicts,
        highest score first. Scores are plain Python floats. Regions for which
        ``find_similarity`` returned an error message keep that message under
        ``"similarity"`` and are placed after all scored regions.

    Raises:
        Whatever ``base64.b64decode`` or ``Image.open`` raise when the input
        cannot be decoded; those errors are not caught here.
    """
    image_bytes = base64.b64decode(input_image)

    # PIL decodes to RGB order. Force three channels so RGBA, palette and
    # greyscale uploads work too, and keep RGB order for the mask generator.
    image = np.array(Image.open(BytesIO(image_bytes)).convert("RGB"))

    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

    scored_regions = []  # regions with a numeric similarity score
    failed_regions = []  # regions where find_similarity returned an error string

    for mask_dict in masks:
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
        segmented_region = cv2.bitwise_and(image, image, mask=mask_data)

        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y+h, x:x+w]
        if cropped_region.size == 0:
            # A degenerate bounding box has nothing to encode or score.
            continue

        # cv2.imencode expects BGR input, so convert from RGB right before
        # encoding; the PNG then decodes back to the original colours.
        encoded_ok, buffer = cv2.imencode(
            ".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR)
        )
        if not encoded_ok:
            continue
        segmented_image_base64 = base64.b64encode(buffer).decode()

        # Calculate similarity for the segmented image
        similarity = find_similarity(segmented_image_base64, text_input)

        if isinstance(similarity, str):
            # find_similarity reports failures as a message string; keep it
            # visible to the caller but out of the numeric sort.
            failed_regions.append(
                {"image": segmented_image_base64, "similarity": similarity}
            )
            continue

        # Collapse the returned array to a scalar so the sort key is totally
        # ordered and the result is JSON-serialisable.
        # NOTE(review): assumes one score per region; max() is used so the
        # value stays defined if several come back -- confirm against
        # find_similarity.
        score = float(np.asarray(similarity, dtype=float).max())
        scored_regions.append({"image": segmented_image_base64, "similarity": score})

    # Sort the segmented images by similarity in descending order
    scored_regions.sort(key=lambda region: region["similarity"], reverse=True)

    # Return scored regions first, followed by any that failed to score
    return scored_regions + failed_regions
|
|
|
|
|
|
|
83 |
|
84 |
# Create Gradio components
|
85 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|