# Page-header residue captured with this file: "Spaces: Sleeping".
import gradio as gr | |
import cv2 | |
import numpy as np | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
import base64 | |
import requests | |
from gradio_client import Client | |
# Load the segmentation model.
# The registry maps a model-type key to a builder callable; the builder is
# then given the path of the checkpoint file to load.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
_build_sam = sam_model_registry[model_type]
sam = _build_sam(checkpoint=sam_checkpoint)
# Define a function for image segmentation
def segment_image(input_image):
    """Segment an image with SAM and return every region as a base64 PNG.

    Args:
        input_image: Image array as delivered by the Gradio image input.
            Assumed to be H x W x 3 uint8 in RGB channel order, which is
            Gradio's default numpy format -- NOTE(review): confirm if this
            function is ever called with OpenCV (BGR) images.

    Returns:
        A list of base64-encoded PNG strings, one per generated mask, in the
        order the mask generator produced them. Masks whose bounding box
        yields an empty crop, or that fail to encode, are skipped.
    """
    # The mask generator is fed the image as-is. The previous code applied a
    # BGR->RGB swap here, which on an RGB input hands the model
    # channel-reversed data (and contradicted the RGB assumption made when
    # encoding below).
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(input_image)

    segmented_regions = []  # List to store segmented regions
    for mask_dict in masks:
        # Boolean mask -> 0/255 uint8 mask as required by cv2.bitwise_and.
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
        segmented_region = cv2.bitwise_and(input_image, input_image, mask=mask_data)

        # 'bbox' is unpacked as (x, y, width, height).
        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y + h, x:x + w]
        if cropped_region.size == 0:
            # A zero-width/height box gives an empty crop, which cannot be
            # encoded as an image; skip it instead of failing the request.
            continue

        # cv2.imencode expects BGR, so convert from RGB before encoding.
        encoded_ok, buffer = cv2.imencode(
            ".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR)
        )
        if not encoded_ok:
            continue

        # Convert to base64 image and add to the list
        segmented_regions.append(base64.b64encode(buffer).decode())
    return segmented_regions
# Function to call the API and calculate cosine similarity
def calculate_cosine_similarity(
    segmented_images,
    *,
    space_url="https://ktllc-clip-model-inputbase64.hf.space/--replicas/mmz7z/",
):
    """Score each base64 image via the remote API and report the best one.

    Every image is sent to the ``/predict`` endpoint (passed as both
    arguments, exactly as before) and the ``'similarity'`` field of the
    response is compared against the running maximum. Each score and the
    final winner are printed.

    Args:
        segmented_images: Iterable of base64-encoded image strings.
        space_url: Address of the Gradio app to query. Defaults to the
            endpoint this function has always used; override it when that
            replica address is no longer valid.

    Returns:
        A ``(highest_cosine, highest_cosine_base64)`` tuple. When there are
        no images, or no score exceeds -1, this is ``(-1, "")``.
    """
    highest_cosine = -1
    highest_cosine_base64 = ""

    # Materialise once so generators are supported and emptiness can be
    # checked; the client (a network connection) is only created when there
    # is actually something to score.
    images = list(segmented_images)
    client = Client(space_url) if images else None

    for base64_image in images:
        # Call the API here using the base64 image
        result = client.predict(base64_image, base64_image, api_name="/predict")
        cosine_value = result['similarity']
        print(f"Base64 Image: {base64_image}, Cosine Similarity: {cosine_value}")
        if cosine_value > highest_cosine:
            highest_cosine = cosine_value
            highest_cosine_base64 = base64_image

    print(f"Highest Cosine Similarity: {highest_cosine} (Base64 Image: {highest_cosine_base64})")
    # Previously the result was only printed; return it so callers can use it.
    return highest_cosine, highest_cosine_base64
# Create Gradio components
# (top-level component classes replace the deprecated gr.inputs / gr.outputs
# namespaces)
input_image = gr.Image()
output_images = gr.JSON()


def _segment_and_score(image):
    """Segment the uploaded image, log similarity scores, return the regions.

    The similarity step used to sit after ``launch()`` at module level, where
    it could not run while the app was serving; doing it inside the callback
    makes it execute for every submitted image.
    """
    segmented_images = segment_image(image)
    # Call the API for each segmented image and calculate cosine similarity
    calculate_cosine_similarity(segmented_images)
    return segmented_images


# Create a Gradio interface
segmentation_interface = gr.Interface(
    fn=_segment_and_score,
    inputs=input_image,
    outputs=output_images,
)

# Launch the segmentation interface (only when executed as a script, so that
# importing this module does not start a server)
if __name__ == "__main__":
    segmentation_interface.launch()