"""Gradio demo: zero-shot, text-prompted image segmentation with CLIPSeg."""
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.measure import label, regionprops
from transformers import (
    AutoConfig,
    AutoProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
    pipeline,
)

# CLIPSeg processor/model pair, loaded once at import time and shared by
# every request.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

classes = list()


def create_rgb_mask(mask):
    """Colour a binary mask with one random RGB colour.

    Args:
        mask: 2-D array whose foreground pixels equal 255 (background 0).

    Returns:
        A ``(H, W, 3)`` uint8 image in which pixels equal to 255 take a single
        randomly chosen colour; all other pixels keep their grey value.
    """
    # np.where(...) produces int64, which cv2.merge / cv2.addWeighted do not
    # accept alongside uint8 images, so normalise the dtype up front.
    mask = np.asarray(mask).astype(np.uint8)
    color = tuple(np.random.choice(range(0, 256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)


def detect_using_clip(image, prompts=None, threshould=0.4):
    """Segment ``image`` once per text prompt with CLIPSeg.

    Args:
        image: Input image as a NumPy array (as delivered by the Gradio
            ``"image"`` input).
        prompts: Iterable of text prompts; one mask is produced per prompt.
            ``None`` or an empty iterable yields an empty list.
        threshould: Sigmoid probability above which a pixel counts as
            foreground. (Misspelled name kept for backward compatibility.)

    Returns:
        A list of uint8 RGB mask images, one per prompt and in prompt order,
        at the model's output resolution (presumably 352x352, matching the
        resize in ``visualize_images`` — confirm against the checkpoint).
    """
    prompts = list(prompts) if prompts else []
    if not prompts:
        # The processor cannot be called with zero text/image pairs.
        return []

    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)

    logits = outputs.logits
    # Some transformers versions squeeze the batch axis when only one prompt
    # is given, returning (H, W) instead of (1, H, W). Restore it so that
    # logits[i] is always the i-th prompt's 2-D map.
    if logits.ndim == 2:
        logits = logits.unsqueeze(0)

    probabilities = torch.sigmoid(logits).detach().cpu().numpy()
    return [
        create_rgb_mask(np.where(prob > threshould, 255, 0))
        for prob in probabilities
    ]


def visualize_images(image, predicted_images):
    """Overlay colour masks on a resized copy of ``image`` and boost contrast.

    Args:
        image: Input image as a uint8 NumPy array.
        predicted_images: Colour masks from ``detect_using_clip``; each must be
            352x352 with the same channel count and dtype as ``image``.

    Returns:
        A 352x352 uint8 image: the masks alpha-blended one after another onto
        the resized input, followed by a linear brightness/contrast boost.
    """
    alpha = 0.7  # weight kept by the accumulated image at each blend step
    blended = cv2.resize(image, (352, 352))
    for mask_image in predicted_images:
        blended = cv2.addWeighted(blended, alpha, mask_image, 1 - alpha, 10)
    return cv2.convertScaleAbs(blended, alpha=1.8, beta=15)


def shot(image, labels_text, selected_categoty=None):
    """Gradio callback: segment every comma-separated label in ``labels_text``.

    Args:
        image: Input image from the Gradio ``"image"`` component.
        labels_text: Comma-separated category names; surrounding whitespace
            and empty entries (e.g. from a trailing comma) are ignored.
        selected_categoty: Unused. Defaults to ``None`` because the Interface
            wires up only two inputs.

    Returns:
        The blended visualisation produced by ``visualize_images``.
    """
    prompts = [part.strip() for part in (labels_text or "").split(",")]
    prompts = [prompt for prompt in prompts if prompt]
    predicted_images = detect_using_clip(image, prompts=prompts)
    return visualize_images(image=image, predicted_images=predicted_images)


iface = gr.Interface(
    fn=shot,
    inputs=["image", "text"],
    outputs="image",
    description="Add an image and a list of categories to be detected, separated by commas",
    title="Zero-shot Image Segmentation with Prompt ",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window,light"],
        ["images/image2.png", "banner, building,door, sign,"],
        ["images/seats.jpg", "door,table,chairs"],
        ["images/vegetables.jpg", "carrot,radish,beans,potato,brinjal,basket"],
        [
            "images/room2.jpg",
            "door,plants,dog,coffee table,mug,pillow,table lamp,carpet,pictures,door,clock",
        ],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()