import gradio as gr
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import base64
from PIL import Image
from io import BytesIO
import torch
import clip
# Select a device up front; the GPU is used when available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM segmentation model
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Load the CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()
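
# Note: "sam_vit_h_4b8939.pth" is the ViT-H SAM checkpoint; it is not bundled
# with the segment-anything package and must be downloaded separately, e.g.
# from https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth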

def find_similarity(base64_image, text_input):
    """Score how well a base64-encoded image matches a text prompt with CLIP."""
    try:
        # Decode the base64 image to bytes and open it as an RGB PIL image
        image_bytes = base64.b64decode(base64_image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        # Preprocess the image and tokenize the text
        image = preprocess(image).unsqueeze(0).to(device)
        text_tokens = clip.tokenize([text_input]).to(device)

        # Encode image and text features
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

        # Normalize features; the dot product is then the cosine similarity
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        # .item() yields a plain Python float, which sorts and JSON-serializes cleanly
        return (text_features @ image_features.T).item()
    except Exception:
        # Cosine similarity is bounded below by -1, so a failed region sorts last
        # (returning the exception string here would break the sort later)
        return -1.0
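
# A minimal sketch of calling find_similarity directly, assuming a local image
# at the hypothetical path "example.png":
#
#   with open("example.png", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   print(find_similarity(b64, "a red car"))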

# Segment the image with SAM and rank every region by CLIP similarity
def segment_image(input_image, text_input):
    image_bytes = base64.b64decode(input_image)
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    # Convert the PIL image to an RGB numpy array for SAM
    image = np.array(image)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    segmented_regions = []  # Segmented regions with their similarity scores
    for mask_dict in masks:
        # Keep only the pixels covered by this mask
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)
        segmented_region = cv2.bitwise_and(image, image, mask=mask_data)
        # Crop to the mask's bounding box (XYWH format)
        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y+h, x:x+w]
        # Encode as PNG; OpenCV expects BGR, so convert from RGB first
        _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR))
        segmented_image_base64 = base64.b64encode(buffer).decode()
        # Score the cropped region against the text prompt
        similarity = find_similarity(segmented_image_base64, text_input)
        segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})
    # Sort the segmented regions by similarity, best match first
    segmented_regions.sort(key=lambda region: region["similarity"], reverse=True)
    return segmented_regions
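
# A minimal sketch of using segment_image outside Gradio, reusing the
# hypothetical "example.png" encoding from above; the top-ranked crop is
# written back to disk:
#
#   regions = segment_image(b64, "a red car")
#   if regions:
#       with open("best_region.png", "wb") as f:
#           f.write(base64.b64decode(regions[0]["image"]))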

# Create the Gradio interface components
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")
# gr.outputs.JSON was removed in newer Gradio releases; gr.JSON is the current API
output_images = gr.JSON(label="Segmented Regions")

# Build and launch the Gradio app
gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
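
# The interface expects the image as a raw base64 string (no "data:...;base64,"
# prefix). A minimal sketch of producing one from a shell, assuming the same
# hypothetical "example.png":
#
#   python -c "import base64; print(base64.b64encode(open('example.png', 'rb').read()).decode())"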