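"""Gradio app: segments a base64-encoded image with Segment Anything (SAM),
scores each segment against a text prompt with CLIP, and forwards the best
matches to a remote CLIP inference Space."""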
import base64
from io import BytesIO

import clip
import cv2
import gradio as gr
import numpy as np
import torch
from gradio_client import Client
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (SAM ViT-H checkpoint)
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Load the CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()

def find_similarity(base64_image, text_input):
    """Return the CLIP cosine similarity between a base64 image and a text prompt."""
    try:
        # Decode the base64 string and load it as an RGB PIL image
        image_bytes = base64.b64decode(base64_image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        # Preprocess the image and tokenize the text
        image = preprocess(image).unsqueeze(0).to(device)
        text_tokens = clip.tokenize([text_input]).to(device)

        # Encode image and text features
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

        # Normalize features and compute cosine similarity
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
        return float(similarity)
    except Exception as e:
        # Return a sentinel score instead of an error string so sorting stays numeric
        print(f"find_similarity failed: {e}")
        return -1.0
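
# Note: because both feature vectors are normalized, the score above is a
# cosine similarity in [-1, 1]; higher means the segment matches the prompt better.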

def segment_image(input_image, text_input):
    # Decode the base64 input and convert it to an RGB numpy array
    image_bytes = base64.b64decode(input_image)
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    image = np.array(image)

    # Generate a mask for every region SAM finds in the image
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

    segmented_regions = []  # Segmented regions with their similarity scores
    for mask_dict in masks:
        mask_data = (mask_dict['segmentation'] * 255).astype(np.uint8)

        # Broadcast the single-channel mask across the three color channels
        mask = np.zeros_like(image)
        mask[:, :] = mask_data[:, :, np.newaxis]

        # Apply the mask and crop to the mask's bounding box
        segmented_region = cv2.bitwise_and(image, mask)
        x, y, w, h = map(int, mask_dict['bbox'])
        cropped_region = segmented_region[y:y+h, x:x+w]

        # Encode as a base64 PNG (OpenCV expects BGR, so swap channels first)
        _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_RGB2BGR))
        segmented_image_base64 = base64.b64encode(buffer).decode()

        # Score the segment against the text prompt
        similarity = find_similarity(segmented_image_base64, text_input)
        segmented_regions.append({"image": segmented_image_base64, "similarity": similarity})

    # Sort segments by similarity in descending order and keep the top 6
    segmented_regions.sort(key=lambda x: x["similarity"], reverse=True)
    segmented_regions = segmented_regions[:6]

    # Run each top segment through the remote inference Space
    client = Client("https://ktllc-clip-model-inputbase64.hf.space/--replicas/dv889/")
    results_with_similarity = []
    for result in segmented_regions:
        image_base64 = result["image"]
        similarity = result["similarity"]
        api_result = client.predict(image_base64, text_input, api_name="/predict")
        results_with_similarity.append(
            {"api_result": api_result, "image": image_base64, "similarity": similarity}
        )

    # Keep only the single best match
    results_with_similarity = results_with_similarity[:1]

    # Return the results in descending order of similarity
    return results_with_similarity

# Create Gradio components
input_image = gr.Textbox(label="Base64 Image", lines=8)
text_input = gr.Textbox(label="Text Input")
output_images = gr.JSON()  # gr.outputs.JSON() was removed in Gradio 3+

# Create and launch the Gradio interface
gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs=output_images).launch()
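
# Example client usage (a sketch: assumes a local file "example.png" and that
# this Space is reachable at the placeholder URL below):
#
#   import base64
#   from gradio_client import Client
#
#   with open("example.png", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode()
#
#   client = Client("https://<your-space>.hf.space/")
#   print(client.predict(image_b64, "a dog", api_name="/predict"))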