File size: 2,287 Bytes
c0c08a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from transformers import pipeline
import gradio
from PIL import Image
from IPython.display import display, HTML
import base64
from PIL import Image
from io import BytesIO
from sentence_transformers import SentenceTransformer, util


# All models are loaded once, at import time, and shared by every request.
# Scene-level segmentation model; each result it returns carries a "label" and a "mask".
backgroundPipe = pipeline(
    task="image-segmentation",
    model="facebook/maskformer-swin-large-coco",
)
# Second segmentation model whose labels are merged with the scene labels above
# (label names come from the model's own config — see its model card).
PersonPipe = pipeline(
    task="image-segmentation",
    model="mattmdjaga/segformer_b2_clothes",
)
# Sentence-embedding model used to match the user's free-text query to a label.
sentenceModal = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
)

def getImageDetails(image) -> tuple:
    """Run both segmentation pipelines on *image* and collect their masks by label.

    Args:
        image: The picture passed unchanged to ``PersonPipe`` and
            ``backgroundPipe`` (the Gradio UI supplies a PIL image).

    Returns:
        A ``(masks, labels)`` tuple. ``masks`` maps each label to its mask image;
        ``labels`` lists every distinct label once, in first-seen order
        (``backgroundPipe`` labels first, then ``PersonPipe`` labels).
        If a label is reported more than once (e.g. several instances of the
        same class), the masks are merged so the entry covers every instance
        rather than only the last one.
    """
    from PIL import ImageChops  # local import: only needed to merge duplicate-label masks

    person_segments = PersonPipe(image)
    scene_segments = backgroundPipe(image)

    masks = {}
    for segment in (*scene_segments, *person_segments):
        label = segment["label"]
        mask = segment["mask"]  # Apply base64 image converter here if needed
        previous = masks.get(label)
        if (
            previous is not None
            and previous.size == mask.size
            and previous.mode == mask.mode
        ):
            # Union of the two masks (pixel-wise maximum) so no instance is lost.
            mask = ImageChops.lighter(previous, mask)
        # Re-assigning an existing key keeps its original insertion position,
        # so list(masks) below preserves first-seen label order.
        masks[label] = mask
    return masks, list(masks)

def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* whose embedding best matches *sentence*.

    The query and every candidate are embedded with ``sentenceModal`` and
    scored with ``util.dot_score``. The highest-scoring candidate is returned;
    on a tie, the earliest candidate wins.

    Args:
        sentence: Free-text query, e.g. "shirt".
        semilist: Non-empty, indexable sequence of candidate strings.

    Raises:
        ValueError: If *semilist* is empty (there is nothing to choose from).
    """
    if not semilist:
        raise ValueError("processSentence() needs at least one candidate label")
    query_embedding = sentenceModal.encode(sentence)
    candidate_embeddings = sentenceModal.encode(semilist)
    # Row 0 of the score matrix holds this single query's score for each candidate.
    scores = util.dot_score(query_embedding, candidate_embeddings)[0]
    # argmax returns the index of the first maximal value, i.e. ties go to the
    # earliest candidate.
    best_index = int(scores.argmax())
    return semilist[best_index]

def process_image(image):
    """Turn a black/white mask into an opaque black cut-out on a transparent background.

    The image is converted to RGBA and each pixel is remapped:

    * pure black ``(0, 0, 0)``       -> fully transparent ``(0, 0, 0, 0)``;
    * pure white ``(255, 255, 255)`` -> black ``(0, 0, 0)``, alpha kept;
    * any other colour               -> left unchanged.

    (Equivalent to swapping black and white, then making white transparent,
    but done in a single pass without an intermediate image.)

    Args:
        image: A PIL image; in this file, the segmentation mask chosen by
            ``processAndGetMask``.

    Returns:
        A new RGBA image of the same size.
    """
    rgba_image = image.convert("RGBA")

    def remap(pixel):
        rgb = pixel[:3]
        if rgb == (0, 0, 0):
            return (0, 0, 0, 0)
        if rgb == (255, 255, 255):
            return (0, 0, 0, pixel[3])
        return pixel

    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata([remap(pixel) for pixel in rgba_image.getdata()])
    return processed_image

def processAndGetMask(image: Image.Image, text: str):
    """Gradio callback: return the mask of the segment that best matches *text*.

    Args:
        image: The uploaded picture (the Gradio input is declared with
            ``type="pil"``, so this is a PIL image, or None if nothing was
            uploaded).
        text: Free-text description of the region to extract, e.g. "shirt".

    Returns:
        The RGBA image produced by ``process_image`` for the chosen segment's mask.

    Raises:
        ValueError: If no image was provided or no segments were detected.
    """
    if image is None:
        raise ValueError("Please upload an image first.")
    masks, labels = getImageDetails(image)
    if not labels:
        raise ValueError("No segments were detected in the image.")
    # Pick the detected label whose embedding is closest to the user's text.
    selector = processSentence(text, labels)
    return process_image(masks[selector])

# Web UI: an image plus a text query go in; the matching mask comes out.
gr = gradio.Interface(
    fn=processAndGetMask,
    inputs=[gradio.Image(type="pil"), gradio.Text()],
    outputs=gradio.Image(type="pil"),
)
# Launching at module level starts the server as soon as this file is executed.
gr.launch()