Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
import cv2
|
3 |
import numpy as np
|
4 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
5 |
import base64
|
@@ -49,37 +48,44 @@ def find_similarity(base64_image, text_input):
|
|
49 |
|
50 |
# Define a function for image segmentation
|
51 |
def segment_image(input_image, text_input):
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
60 |
|
61 |
-
|
62 |
-
mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
|
63 |
-
segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
|
64 |
|
65 |
-
|
66 |
-
|
|
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
segmented_image_base64 = base64.b64encode(buffer).decode()
|
71 |
|
72 |
-
|
73 |
-
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
|
78 |
-
|
79 |
-
|
80 |
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
# Create Gradio components
|
85 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import numpy as np
|
3 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
4 |
import base64
|
|
|
48 |
|
49 |
# Define a function for image segmentation
|
50 |
def segment_image(input_image, text_input):
|
51 |
+
try:
|
52 |
+
image_bytes = base64.b64decode(input_image)
|
53 |
+
image = Image.open(BytesIO(image_bytes))
|
54 |
+
|
55 |
+
# Convert the image to RGB color mode
|
56 |
+
image = image.convert("RGB")
|
57 |
+
|
58 |
+
mask_generator = SamAutomaticMaskGenerator(sam)
|
59 |
+
masks = mask_generator.generate(image)
|
60 |
|
61 |
+
segmented_regions = [] # List to store segmented regions with similarity scores
|
|
|
|
|
62 |
|
63 |
+
for i, mask_dict in enumerate(masks):
|
64 |
+
mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
|
65 |
+
segmented_region = Image.fromarray(mask_data) # Convert mask to an image
|
66 |
|
67 |
+
x, y, w, h = map(int, mask_dict['bbox'])
|
68 |
+
cropped_region = image.crop((x, y, x + w, y + h))
|
|
|
69 |
|
70 |
+
# Convert to base64 image
|
71 |
+
buffered = BytesIO()
|
72 |
+
cropped_region.save(buffered, format="PNG")
|
73 |
+
segmented_image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
74 |
|
75 |
+
# Calculate similarity for the segmented image
|
76 |
+
similarity = find_similarity(segmented_image_base64, text_input)
|
77 |
|
78 |
+
# Append the segmented image and its similarity score
|
79 |
+
segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
|
80 |
|
81 |
+
# Sort the segmented images by similarity in descending order
|
82 |
+
segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
|
83 |
+
|
84 |
+
# Return the segmented images in descending order of similarity
|
85 |
+
return segmented_regions
|
86 |
+
|
87 |
+
except Exception as e:
|
88 |
+
return str(e)
|
89 |
|
90 |
# Create Gradio components
|
91 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|