import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def create_rgb_mask(mask):
    # Turn a binary (0/255) mask into a 3-channel mask with one random color.
    color = tuple(np.random.choice(range(256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)
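# Sketch: create_rgb_mask(np.array([[0, 255], [255, 0]], dtype=np.uint8))
# returns a (2, 2, 3) uint8 array whose two foreground pixels share one random color.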


def detect_using_clip(image, prompts, threshold=0.4):
    # Run CLIPSeg on the image once per text prompt; return one colored mask per prompt.
    predicted_masks = []
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only; disable gradient computation
        outputs = model(**inputs)

    preds = outputs.logits
    if preds.ndim == 2:
        # A single prompt can come back without the batch dimension.
        preds = preds.unsqueeze(0)
    preds = preds.unsqueeze(1)  # (num_prompts, 1, 352, 352)

    for i in range(len(prompts)):
        # Sigmoid maps logits to probabilities; threshold into a 0/255 binary mask.
        predicted_image = torch.sigmoid(preds[i][0]).cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        predicted_masks.append(create_rgb_mask(predicted_image))

    return predicted_masks
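# Usage sketch (assumes a local copy of the bundled example image):
#   img = cv2.cvtColor(cv2.imread("images/seats.jpg"), cv2.COLOR_BGR2RGB)
#   masks = detect_using_clip(img, prompts=["door", "table", "chairs"])
#   -> len(masks) == 3; each mask is a (352, 352, 3) uint8 array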

def visualize_images(image, predicted_images, brightness=15, contrast=1.8):
    # Blend each colored mask over the input image, then adjust brightness/contrast.
    alpha = 0.7
    image_resize = cv2.resize(image, (352, 352))  # match CLIPSeg's 352x352 masks
    resize_image_copy = image_resize.copy()

    for mask_image in predicted_images:
        resize_image_copy = cv2.addWeighted(resize_image_copy, alpha, mask_image, 1 - alpha, 10)

    return cv2.convertScaleAbs(resize_image_copy, alpha=contrast, beta=brightness)
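# Note: cv2.convertScaleAbs computes saturate(|contrast * pixel + brightness|),
# clipping the blended overlay into the displayable [0, 255] range.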

def shot(brightness, contrast, image, labels_text):
    # Split the comma-separated label text into individual prompts.
    prompts = [label.strip() for label in labels_text.split(",")]
    predicted_images = detect_using_clip(image, prompts=prompts)

    category_image = visualize_images(
        image=image,
        predicted_images=predicted_images,
        brightness=brightness,
        contrast=contrast,
    )
    return category_image
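# End-to-end sketch without the UI (mirrors the Gradio examples below):
#   img = cv2.cvtColor(cv2.imread("images/room2.jpg"), cv2.COLOR_BGR2RGB)
#   overlay = shot(17, 2, img, "door, plants, dog")
#   cv2.imwrite("overlay.jpg", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))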

iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Slider(5, 50, value=15, label="Brightness", info="Choose between 5 and 50"),
        gr.Slider(1, 5, value=1.5, label="Contrast", info="Choose between 1 and 5"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an image and a comma-separated list of categories to detect (at least 2).",
    title="Zero-shot Image Segmentation with Prompt",
    examples=[
        [19, 1.5, "images/seats.jpg", "door,table,chairs"],
        [20, 1.8, "images/vegetables.jpg", "carrot,white radish,brinjal,basket,potato"],
        [17, 2, "images/room2.jpg", "door, plants, dog, coffee table, table lamp, carpet"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()