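# Zero-shot image segmentation Gradio Space built on CLIPSeg (CIDAS/clipseg-rd64-refined):
# the user supplies an image, a comma-separated list of text prompts, and one selected
# category; the app overlays the predicted mask for that category on the resized input image.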
import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops
# Load the CLIPSeg processor and segmentation model once at startup
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
def create_mask(image, image_mask, alpha=0.7):
    # Broadcast the single-channel mask to all three colour channels of the image
    mask = np.zeros_like(image)
    for i in range(3):
        mask[:, :, i] = image_mask.copy()
    # Blend the mask over the original image
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image
def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # Rescale a (y1, x1, y2, x2) box from model resolution back to the original image size
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]
def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    predicted_images = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    for i, prompt in enumerate(prompts):
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, np.random.randint(128, 255), 0)
        # Extract connected regions (contours) from the binarised mask
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        prompt = prompt.lower()
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
        predicted_images[prompt] = predicted_image
    return model_detections, predicted_images
def visualize_images(image, detections, predicted_images, prompt):
    alpha = 0.7
    prompt = prompt.lower()
    image_resize = cv2.resize(image, (352, 352))
    # Fall back to the plain (resized) image if the selected category was never queried
    if prompt not in detections:
        print("prompt not in query ..")
        return image_resize
    mask_image = create_mask(image=image_resize, image_mask=predicted_images[prompt])
    final_image = cv2.addWeighted(image_resize, alpha, mask_image, 1 - alpha, 0)
    return final_image
def shot(image, labels_text, selected_category):
    # Split the comma-separated category list into individual prompts
    prompts = [prompt.strip() for prompt in labels_text.split(",")]
    model_detections, predicted_images = detect_using_clip(image, prompts=prompts)
    category_image = visualize_images(
        image=image,
        detections=model_detections,
        predicted_images=predicted_images,
        prompt=selected_category,
    )
    return category_image
iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", "text"],
    outputs="image",
    description="Add an image and a comma-separated list of categories to detect (at least 2), then choose one category to visualise",
    title="Zero-shot Image Segmentation with Prompt",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window", "plant"],
        ["images/image2.png", "banner, building, door, sign", "sign"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
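# Local smoke test sketch (assumes the example image exists; Gradio passes the image
# to shot() as an RGB numpy array, so convert from OpenCV's BGR first):
# img = cv2.cvtColor(cv2.imread("images/room.jpg"), cv2.COLOR_BGR2RGB)
# out = shot(img, "bed, table, plant, light, window", "plant")
# cv2.imwrite("overlay.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))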