"""Gradio app: segment an image with SAM and rank the segments by CLIP text similarity."""

import base64
from io import BytesIO

import clip
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Maximum number of segmented regions returned per request.
TOP_K = 6

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model and place it on the same device as CLIP.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# The generator only wraps the model, so build it once instead of per request.
mask_generator = SamAutomaticMaskGenerator(sam)

# Load the CLIP model used to score each segment against the text prompt.
model, preprocess = clip.load("ViT-L/14@336px", device=device)
model.to(device).eval()


def _decode_base64_image(base64_image):
    """Decode a base64 string (optionally a ``data:...;base64,`` URL) into a PIL image."""
    if base64_image.startswith("data:"):
        # Strip the data-URL header; b64decode would otherwise mix it into the payload.
        base64_image = base64_image.split(",", 1)[-1]
    return Image.open(BytesIO(base64.b64decode(base64_image)))


def _encode_text(text_input):
    """Return the L2-normalised CLIP embedding of ``text_input`` (one row per prompt)."""
    text_tokens = clip.tokenize([text_input]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
    return text_features / text_features.norm(dim=-1, keepdim=True)


def _image_text_similarity(pil_image, text_features):
    """Return the cosine similarity of ``pil_image`` to pre-normalised ``text_features``.

    The result is a numpy array of shape (1,) for a single text prompt.
    """
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return (text_features @ image_features.T).squeeze(0).cpu().numpy()


def find_similarity(base64_image, text_input):
    """Compute CLIP similarity between a base64-encoded image and a text prompt.

    Args:
        base64_image: base64-encoded image (a ``data:`` URL prefix is accepted).
        text_input: text prompt to compare against.

    Returns:
        A numpy array of shape (1,) holding the cosine similarity, or the
        error message as a string if anything fails (best-effort contract).
    """
    try:
        image = _decode_base64_image(base64_image)
        return _image_text_similarity(image, _encode_text(text_input))
    except Exception as e:  # deliberate best-effort: report instead of raising
        return str(e)


def segment_image(input_image, text_input):
    """Segment an image with SAM and rank the segments by CLIP similarity to text.

    Args:
        input_image: base64-encoded image (a ``data:`` URL prefix is accepted).
        text_input: text prompt each segment is scored against.

    Returns:
        Up to ``TOP_K`` dicts ``{"image": <base64 PNG of the masked, cropped
        segment>, "similarity": <float>}``, sorted by descending similarity.
        Segments with an empty bounding box are skipped.
    """
    image = np.array(_decode_base64_image(input_image).convert("RGB"))
    masks = mask_generator.generate(image)

    # The prompt is identical for every segment, so embed it only once.
    text_features = _encode_text(text_input)

    segmented_regions = []
    for mask_dict in masks:
        x, y, w, h = map(int, mask_dict["bbox"])
        # Crop first so per-segment work scales with the bbox, not the full image.
        region = image[y:y + h, x:x + w].copy()
        if region.size == 0:
            # Degenerate bbox: skip this segment instead of aborting the request.
            continue
        segmentation = np.asarray(mask_dict["segmentation"], dtype=bool)[y:y + h, x:x + w]
        region[~segmentation] = 0  # black out pixels outside the mask

        # The array is RGB (it came from PIL) but cv2.imencode expects BGR.
        ok, buffer = cv2.imencode(".png", cv2.cvtColor(region, cv2.COLOR_RGB2BGR))
        if not ok:
            continue

        # Score the in-memory crop directly (PNG is lossless, so no round trip needed).
        similarity = _image_text_similarity(Image.fromarray(region), text_features)
        segmented_regions.append(
            {
                "image": base64.b64encode(buffer).decode(),
                "similarity": float(similarity[0]),
            }
        )

    segmented_regions.sort(key=lambda region_info: region_info["similarity"], reverse=True)
    return segmented_regions[:TOP_K]


# Create Gradio components
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")

# Create and launch the Gradio interface
demo = gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs="text")
demo.launch()