import gradio as gr
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import base64
from PIL import Image
from io import BytesIO
import torch
import clip
# Select the compute device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM segmentation model (the ViT-H checkpoint from the
# segment-anything repository must be present in the working directory)
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Load the CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()
def find_similarity(base64_image, text_input):
    """Return the CLIP cosine similarity between an image and a text prompt."""
    try:
        # Decode the base64 image to bytes
        image_bytes = base64.b64decode(base64_image)

        # Convert the bytes to an RGB PIL image
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        # Preprocess the image for CLIP
        image = preprocess(image).unsqueeze(0).to(device)

        # Tokenize the input text
        text_tokens = clip.tokenize([text_input]).to(device)

        # Encode image and text features
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

        # Normalize the features and compute the cosine similarity
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return (text_features @ image_features.T).item()
    except Exception as e:
        # A string score would break the sort in segment_image, so log the
        # error and return the lowest possible score instead
        print(f"find_similarity failed: {e}")
        return -1.0

# Define a function for image segmentation
def segment_image(input_image, text_input):
    # Decode the base64 input into an RGB numpy array
    image_bytes = base64.b64decode(input_image)
    image = np.array(Image.open(BytesIO(image_bytes)).convert("RGB"))

    # Generate a mask for every region SAM can find
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

    segmented_regions = []  # Segmented regions with their similarity scores
    for mask_dict in masks:
        # Apply the binary mask and crop to the region's bounding box
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
        segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y+h, x:x+w]

        # Encode the crop as a base64 PNG (OpenCV expects BGR channel order)
        _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR))
        segmented_image_base64 = base64.b64encode(buffer).decode()

        # Score the cropped region against the text prompt
        similarity = find_similarity(segmented_image_base64, text_input)
        segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})

    # Sort the segmented regions by similarity in descending order
    segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
    return segmented_regions

# Create Gradio components
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")
output_images = gr.JSON(label="Segmented Regions")

# Create and launch the Gradio interface
gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
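
# A minimal client-side sketch of how the "Base64 Image" textbox value can be
# produced; "example.jpg" is a hypothetical file, not part of this app:
#
#   import base64
#   with open("example.jpg", "rb") as f:
#       b64_string = base64.b64encode(f.read()).decode()
#
# Paste b64_string into the "Base64 Image" textbox and enter a prompt such as
# "a dog"; the app returns the cropped regions ranked by CLIP similarity.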