import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

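# Gradio demo: zero-shot image segmentation with CLIPSeg.
# The user supplies an image plus comma-separated text prompts; each prompt is
# segmented by CLIPSeg and overlaid on the image as a randomly coloured mask.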
# Load the CLIPSeg processor and model once at startup
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def create_rgb_mask(mask):
    # Colour a binary (0/255) mask with a random RGB colour so that
    # masks for different prompts remain distinguishable when overlaid.
    color = tuple(np.random.choice(range(256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)


def detect_using_clip(image, prompts=None, threshold=0.4):
    prompts = prompts or []
    predicted_masks = []
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    # Inference only, so disable gradient computation
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    for i, _ in enumerate(prompts):
        # Sigmoid turns the logits into per-pixel probabilities; thresholding
        # gives a binary 0/255 mask for each prompt.
        predicted_image = torch.sigmoid(preds[i][0]).cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        predicted_masks.append(create_rgb_mask(predicted_image))
    return predicted_masks

def visualize_images(image, predicted_images):
    alpha = 0.7
    # CLIPSeg outputs 352x352 masks, so resize the input image to match before blending
    image_resize = cv2.resize(image, (352, 352))
    resize_image_copy = image_resize.copy()

    for mask_image in predicted_images:
        # Overlay each coloured mask onto the (progressively blended) image
        resize_image_copy = cv2.addWeighted(resize_image_copy, alpha, mask_image, 1 - alpha, 10)

    # Lift brightness/contrast so the overlays stand out
    return cv2.convertScaleAbs(resize_image_copy, alpha=1.8, beta=15)
    

def shot(image, labels_text):
    # The text box holds a comma-separated list of category prompts
    prompts = [p.strip() for p in labels_text.split(",") if p.strip()]

    predicted_images = detect_using_clip(image, prompts=prompts)

    category_image = visualize_images(image=image, predicted_images=predicted_images)
    return category_image

iface = gr.Interface(
    fn=shot,
    inputs=["image", "text"],
    outputs="image",
    description="Add an image and a comma-separated list of categories to detect (at least 2).",
    title="Zero-shot Image Segmentation with Prompt",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window"],
        ["images/image2.png", "banner, building, door, sign"],
        ["images/seats.jpg", "door, table, chairs"],
        ["images/vegetables.jpg", "carrot, radish, beans, potato, brinjal, basket"],
        ["images/room2.jpg", "door, plants, dog, coffee table, mug, pillow, table lamp, carpet, pictures, clock"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()