Spaces:
Sleeping
Sleeping
File size: 3,401 Bytes
8ab94bb d666f15 f687418 8ab94bb d666f15 8eb4c63 d666f15 8eb4c63 d666f15 608f6fc d666f15 608f6fc 251edb4 d666f15 608f6fc d666f15 608f6fc d666f15 608f6fc d666f15 608f6fc d666f15 608f6fc f687418 d666f15 8ea55d3 8ab94bb 8ea55d3 8ab94bb 8ea55d3 8ab94bb 8ea55d3 8ab94bb 8ea55d3 8ab94bb 8ea55d3 28d01e2 8ea55d3 8ab94bb 8ea55d3 f687418 8ea55d3 552a7da 8ab94bb cf83b7c fbf28e6 d666f15 8ab94bb d666f15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import base64
from PIL import Image
from io import BytesIO
import torch
import clip
import json
# Choose the compute device first so both models are placed on the same hardware.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Segment Anything (SAM) model used for automatic mask generation.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Without an explicit move, SAM would stay on the CPU even when a GPU exists.
sam.to(device=device)

# Load the CLIP model used to score each segment against the text prompt.
model, preprocess = clip.load("ViT-B/32", device=device)
model.to(device).eval()
def find_similarity(base64_image, text_input):
    """Score how well a base64-encoded image matches a text prompt with CLIP.

    Args:
        base64_image: Base64 string holding encoded image bytes that PIL can open.
        text_input: The prompt to compare the image against.

    Returns:
        On success, the cosine similarity between the CLIP image and text
        embeddings as a NumPy array of shape (1,). Its dtype follows the
        loaded CLIP model's precision (presumably float16 on GPU, float32 on
        CPU; confirm).
        On any exception, a JSON string of the form ``{"error": "<message>"}``.
        This function never raises.
    """
    try:
        pil_img = Image.open(BytesIO(base64.b64decode(base64_image)))
        # Batch of one image, moved to the model's device.
        img_batch = preprocess(pil_img).unsqueeze(0).to(device)
        tokens = clip.tokenize([text_input]).to(device)

        with torch.no_grad():
            img_emb = model.encode_image(img_batch)
            txt_emb = model.encode_text(tokens)

        # Unit-normalise both embeddings so their dot product is the cosine similarity.
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        # (1, D) @ (D, 1) -> (1, 1); squeeze the leading axis to get shape (1,).
        scores = txt_emb @ img_emb.T
        return scores.squeeze(0).cpu().numpy()
    except Exception as err:
        return json.dumps({"error": str(err)})
def segment_image(input_image, text_input):
    """Segment an image with SAM and rank every segment by CLIP similarity.

    Args:
        input_image: Base64 string holding encoded image bytes that PIL can open.
        text_input: Text prompt that each segment is scored against.

    Returns:
        A JSON string. On success it is a list of objects of the form
        ``{"image": <base64 PNG of the segment's bounding-box crop>,
        "similarity": <float>}``, sorted by descending similarity.
        On failure, including a failure reported by ``find_similarity``, it
        is ``{"error": "<message>"}``.
    """
    try:
        image = Image.open(BytesIO(base64.b64decode(input_image))).convert("RGB")

        # SamAutomaticMaskGenerator.generate expects an HxWx3 uint8 NumPy
        # array, not a PIL image.
        mask_generator = SamAutomaticMaskGenerator(sam)
        masks = mask_generator.generate(np.array(image))

        segmented_regions = []  # {"image": ..., "similarity": ...} per mask
        for mask_dict in masks:
            x, y, w, h = map(int, mask_dict["bbox"])
            # SAM derives w/h from inclusive min/max pixel indices
            # (w = x_max - x_min), so +1 keeps the last column/row and
            # guarantees a non-empty crop.
            cropped_region = image.crop((x, y, x + w + 1, y + h + 1))

            buffered = BytesIO()
            cropped_region.save(buffered, format="PNG")
            segmented_image_base64 = base64.b64encode(buffered.getvalue()).decode()

            similarity = find_similarity(segmented_image_base64, text_input)
            if isinstance(similarity, str):
                # find_similarity reports failures as a JSON error string;
                # pass that error straight back instead of mixing it into the sort.
                return similarity

            segmented_regions.append(
                {
                    "image": segmented_image_base64,
                    # Convert the NumPy result (array/scalar) to a plain float
                    # so the list is JSON-serializable.
                    "similarity": float(np.asarray(similarity).item()),
                }
            )

        segmented_regions.sort(key=lambda region: region["similarity"], reverse=True)
        return json.dumps(segmented_regions)
    except Exception as e:
        return json.dumps({"error": str(e)})
# Gradio UI: a base64 image string and a text prompt go in, and the ranked
# segments come back as JSON.
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")  # Use Textbox with a label
# gr.outputs.* was deprecated in Gradio 3 and removed in Gradio 4; gr.JSON is
# the component equivalent. It parses the JSON string returned by segment_image.
output_images = gr.JSON()

# Bound to a module-level name so tooling such as `gradio app.py` can find it.
demo = gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images)

if __name__ == "__main__":
    demo.launch()
|