Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
|
|
7 |
from io import BytesIO
|
8 |
import torch
|
9 |
import clip
|
|
|
10 |
|
11 |
# Load the segmentation model
|
12 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
@@ -47,7 +48,6 @@ def find_similarity(base64_image, text_input):
|
|
47 |
except Exception as e:
|
48 |
return str(e)
|
49 |
|
50 |
-
# Define a function for image segmentation
|
51 |
def segment_image(input_image, text_input):
|
52 |
image_bytes = base64.b64decode(input_image)
|
53 |
image = Image.open(BytesIO(image_bytes))
|
@@ -92,12 +92,25 @@ def segment_image(input_image, text_input):
|
|
92 |
# Limit the output to the top 6 key-value pairs
|
93 |
segmented_regions = segmented_regions[:6]
|
94 |
|
95 |
-
#
|
|
|
96 |
for result in segmented_regions:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
print(result)
|
98 |
|
99 |
# Return the segmented images in descending order of similarity
|
100 |
-
return
|
101 |
|
102 |
# Create Gradio components
|
103 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|
|
|
7 |
from io import BytesIO
|
8 |
import torch
|
9 |
import clip
|
10 |
+
from gradio_client import Client
|
11 |
|
12 |
# Load the segmentation model
|
13 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
|
|
48 |
except Exception as e:
|
49 |
return str(e)
|
50 |
|
|
|
51 |
def segment_image(input_image, text_input):
|
52 |
image_bytes = base64.b64decode(input_image)
|
53 |
image = Image.open(BytesIO(image_bytes))
|
|
|
92 |
# Limit the output to the top 6 key-value pairs
|
93 |
segmented_regions = segmented_regions[:6]
|
94 |
|
95 |
+
# Run each segmented image through the inference API and store results
|
96 |
+
results_with_similarity = []
|
97 |
for result in segmented_regions:
|
98 |
+
image_base64 = result["image"]
|
99 |
+
similarity = result["similarity"]
|
100 |
+
|
101 |
+
# Make an API call with image_base64, get the API result
|
102 |
+
client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/dv889/")
|
103 |
+
api_result = client.predict(image_base64, text_input, api_name="/predict")
|
104 |
+
|
105 |
+
# Append the API result and similarity to the list
|
106 |
+
results_with_similarity.append({"api_result": api_result, "similarity": similarity})
|
107 |
+
|
108 |
+
# Print the top 6 results
|
109 |
+
for result in results_with_similarity:
|
110 |
print(result)
|
111 |
|
112 |
# Return the segmented images in descending order of similarity
|
113 |
+
return results_with_similarity
|
114 |
|
115 |
# Create Gradio components
|
116 |
input_image = gr.Textbox(label="Base64 Image", lines=8)
|