Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
|
|
6 |
from io import BytesIO
|
7 |
import torch
|
8 |
import clip
|
|
|
9 |
|
10 |
# Load the segmentation model
|
11 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
@@ -17,7 +18,6 @@ model, preprocess = clip.load("ViT-B/32")
|
|
17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
model.to(device).eval()
|
19 |
|
20 |
-
|
21 |
def find_similarity(base64_image, text_input):
|
22 |
try:
|
23 |
# Decode the base64 image to bytes
|
@@ -44,9 +44,8 @@ def find_similarity(base64_image, text_input):
|
|
44 |
|
45 |
return similarity
|
46 |
except Exception as e:
|
47 |
-
return str(e)
|
48 |
|
49 |
-
# Define a function for image segmentation
|
50 |
def segment_image(input_image, text_input):
|
51 |
try:
|
52 |
image_bytes = base64.b64decode(input_image)
|
@@ -81,11 +80,11 @@ def segment_image(input_image, text_input):
|
|
81 |
# Sort the segmented images by similarity in descending order
|
82 |
segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
|
83 |
|
84 |
-
# Return the segmented images in
|
85 |
-
return segmented_regions
|
86 |
|
87 |
except Exception as e:
|
88 |
-
return str(e)
|
89 |
|
90 |
# Create Gradio components
|
91 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|
|
|
6 |
from io import BytesIO
|
7 |
import torch
|
8 |
import clip
|
9 |
+
import json
|
10 |
|
11 |
# Load the segmentation model
|
12 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
|
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
model.to(device).eval()
|
20 |
|
|
|
21 |
def find_similarity(base64_image, text_input):
|
22 |
try:
|
23 |
# Decode the base64 image to bytes
|
|
|
44 |
|
45 |
return similarity
|
46 |
except Exception as e:
|
47 |
+
return json.dumps({"error": str(e)})
|
48 |
|
|
|
49 |
def segment_image(input_image, text_input):
|
50 |
try:
|
51 |
image_bytes = base64.b64decode(input_image)
|
|
|
80 |
# Sort the segmented images by similarity in descending order
|
81 |
segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
|
82 |
|
83 |
+
# Return the segmented images in a JSON format
|
84 |
+
return json.dumps(segmented_regions)
|
85 |
|
86 |
except Exception as e:
|
87 |
+
return json.dumps({"error": str(e})
|
88 |
|
89 |
# Create Gradio components
|
90 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|