# Spaces: Sleeping  (page-header text captured with the source; not program code)
import base64
import logging
from io import BytesIO

import clip
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# Pick the compute device once, before loading anything, so that both models
# are placed on the same hardware.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (SAM, ViT-H backbone) from a local checkpoint
# file and move it to the selected device; without the explicit move it would
# stay on the CPU even when CUDA is available.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Load the CLIP model together with its image preprocessing transform,
# directly onto the selected device, and make sure it is in inference mode.
model, preprocess = clip.load("ViT-L/14@336px", device=device)
model.to(device).eval()
def find_similarity(base64_image, text_input):
    """Score how well a base64-encoded image matches a text prompt with CLIP.

    Args:
        base64_image: Base64 string of an encoded image file (any format that
            ``PIL.Image.open`` can read).
        text_input: Text prompt the image is compared against.

    Returns:
        On success, a NumPy array holding the cosine similarity between the
        L2-normalized CLIP text and image embeddings for this single
        image/prompt pair. On failure, the exception message as a ``str``;
        callers must check the type before treating the result as a score.
    """
    try:
        # Decode the base64 payload and open it as a PIL image.
        image_bytes = base64.b64decode(base64_image)
        image = Image.open(BytesIO(image_bytes))

        # Apply the CLIP preprocessing transform and add a batch dimension.
        image = preprocess(image).unsqueeze(0).to(device)

        # Tokenize the prompt as a batch of one.
        text_tokens = clip.tokenize([text_input]).to(device)

        # Inference only: no gradients are needed for the embeddings.
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

        # Normalizing both embeddings makes the dot product below a cosine
        # similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
        return similarity
    except Exception as e:
        # Best-effort contract: the message is returned instead of raising,
        # but the full traceback is logged so failures are not invisible.
        logging.getLogger(__name__).exception("CLIP similarity computation failed")
        return str(e)
def segment_image(input_image, text_input):
    """Segment an image with SAM and rank the segments against a text prompt.

    Args:
        input_image: Base64 string of an encoded image file.
        text_input: Text prompt used to score every segment via
            ``find_similarity``.

    Returns:
        A list of at most six dicts ordered by descending similarity. Each
        dict has ``"image"`` (base64 PNG of the masked segment cropped to its
        bounding box) and ``"similarity"`` (the value returned by
        ``find_similarity``). Segments whose score could not be computed keep
        the error string in ``"similarity"`` and are ordered last. Segments
        with an empty bounding box or a failed PNG encoding are skipped.
    """
    image_bytes = base64.b64decode(input_image)
    # Force RGB so the array always has exactly three channels.
    image = np.array(Image.open(BytesIO(image_bytes)).convert("RGB"))

    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

    ranked = []  # (numeric sort key, result dict) pairs
    for mask_dict in masks:
        # The bounding box is read as (x, y, width, height).
        x, y, w, h = map(int, mask_dict['bbox'])
        crop = image[y:y + h, x:x + w]
        if not crop.size:
            # A degenerate box only invalidates this segment, not the whole
            # request, so move on to the next mask.
            continue

        # Zero out every pixel outside the mask. Cropping first avoids
        # allocating full-size temporaries for each mask.
        segmentation = mask_dict['segmentation'][y:y + h, x:x + w]
        keep = (segmentation * 255).astype(np.uint8)[:, :, np.newaxis]
        cropped_region = cv2.bitwise_and(crop, np.broadcast_to(keep, crop.shape).copy())

        # The pixels are RGB (from PIL) while cv2.imencode expects BGR.
        ok, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR))
        if not ok:
            continue
        segmented_image_base64 = base64.b64encode(buffer).decode()

        similarity = find_similarity(segmented_image_base64, text_input)
        if isinstance(similarity, str):
            # find_similarity reports failures as a message string; comparing
            # it with numeric scores would raise, so rank such regions last.
            sort_key = float("-inf")
        else:
            sort_key = float(np.asarray(similarity).reshape(-1)[0])

        ranked.append(
            (sort_key, {"image": segmented_image_base64, "similarity": similarity})
        )

    # Highest similarity first; keep only the top six results.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return [region for _, region in ranked[:6]]
# Input widgets: a multi-line box for the base64 image payload and a
# single-line box for the text prompt.
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")

# Wire the widgets to segment_image, render its result as text, and start
# the app.
gr.Interface(
    fn=segment_image,
    inputs=[input_image, text_input],
    outputs="text",
).launch()