Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| import base64 | |
| from PIL import Image | |
| from io import BytesIO | |
| import torch | |
| import clip | |
| import json | |
# Pick the compute device once so both models are placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Segment Anything (SAM) ViT-H model used for automatic mask generation.
# The checkpoint path is resolved relative to the current working directory.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# The registry builder does not place the model on the GPU; without this,
# mask generation runs on the CPU even when CUDA is available.
sam.to(device=device)

# Load the CLIP ViT-B/32 model and its matching image preprocessing transform
# directly onto the selected device.
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()
def find_similarity(base64_image, text_input):
    """Score how well a base64-encoded image matches a text prompt using CLIP.

    Args:
        base64_image: Base64-encoded image bytes in any format PIL can open.
            When given as a str, a ``data:<mime>;base64,`` prefix is tolerated.
        text_input: Text prompt to compare against the image.

    Returns:
        The cosine similarity between the CLIP image and text embeddings as a
        plain Python ``float`` (JSON-serializable and directly sortable), or a
        JSON string ``{"error": "..."}`` if any step fails.
    """
    try:
        # Accept data URLs as well as bare base64: "," never occurs in plain
        # base64, so anything before the first comma is a data-URL header.
        if isinstance(base64_image, str) and "," in base64_image:
            base64_image = base64_image.split(",", 1)[1]
        image_bytes = base64.b64decode(base64_image)

        # Close the decoded image once it has been converted and preprocessed.
        with Image.open(BytesIO(image_bytes)) as pil_image:
            image = preprocess(pil_image.convert("RGB")).unsqueeze(0).to(device)
        text_tokens = clip.tokenize([text_input]).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

        # L2-normalise so the dot product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # One text x one image gives a (1, 1) matrix. Use .item() to return a
        # Python float rather than a numpy array, which json.dumps rejects.
        return (text_features @ image_features.T).item()
    except Exception as e:
        return json.dumps({"error": str(e)})
def _decode_rgb_image(encoded):
    """Decode a base64 (or base64 data-URL) image into an RGB PIL image."""
    # "," never occurs in plain base64, so anything before it is a data-URL header.
    if isinstance(encoded, str) and "," in encoded:
        encoded = encoded.split(",", 1)[1]
    with Image.open(BytesIO(base64.b64decode(encoded))) as img:
        return img.convert("RGB")


def _crop_to_bbox(image, bbox):
    """Crop *image* to a SAM ``[x, y, w, h]`` bounding box.

    SAM derives its boxes from inclusive min/max pixel indices (w = x_max - x_min),
    so one pixel is added to reach PIL's exclusive right/bottom edge. This also
    keeps single-row/column masks from producing an empty crop. The result is
    clamped to the image bounds.
    """
    x, y, w, h = map(int, bbox)
    right = min(x + w + 1, image.width)
    bottom = min(y + h + 1, image.height)
    return image.crop((x, y, right, bottom))


def _clip_scores(crops, text):
    """Return the CLIP cosine similarity of each crop to *text*, as Python floats."""
    # Batch every crop through CLIP at once, and encode the prompt only once.
    batch = torch.stack([preprocess(crop) for crop in crops]).to(device)
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(batch)
        text_features = model.encode_text(tokens)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # (N, D) @ (D, 1) -> (N, 1) -> N floats.
    return (image_features @ text_features.T).squeeze(1).tolist()


def _png_base64(image):
    """Encode a PIL image as a base64 PNG string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()


def segment_image(input_image, text_input):
    """Segment an image with SAM and rank every region by CLIP similarity to a prompt.

    Args:
        input_image: Base64-encoded image. When given as a str, a
            ``data:<mime>;base64,`` prefix is tolerated.
        text_input: Text prompt used to score each segmented region.

    Returns:
        A JSON string holding a list of ``{"image": <base64 PNG of the region's
        bounding-box crop>, "similarity": <float>}`` entries sorted by descending
        similarity (an empty list when SAM finds no masks), or
        ``{"error": "..."}`` if any step fails.
    """
    try:
        image = _decode_rgb_image(input_image)

        # SamAutomaticMaskGenerator.generate expects an HWC uint8 numpy array,
        # not a PIL image.
        masks = SamAutomaticMaskGenerator(sam).generate(np.array(image))
        if not masks:
            return json.dumps([])

        crops = [_crop_to_bbox(image, mask_dict["bbox"]) for mask_dict in masks]
        scores = _clip_scores(crops, text_input)

        segmented_regions = [
            {"image": _png_base64(crop), "similarity": score}
            for crop, score in zip(crops, scores)
        ]
        # Most similar regions first.
        segmented_regions.sort(key=lambda region: region["similarity"], reverse=True)
        return json.dumps(segmented_regions)
    except Exception as e:
        return json.dumps({"error": str(e)})
# Gradio UI: a base64 image string and a text prompt in, ranked regions (JSON) out.
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")
# gr.outputs.* was deprecated in Gradio 3 and removed in Gradio 4; the top-level
# gr.JSON component is the supported replacement. It also accepts the JSON
# string that segment_image returns.
output_images = gr.JSON()

demo = gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images)
demo.launch()