SelectByText / app.py
5m4ck3r's picture
Update app.py
4597d1f
raw
history blame
2.22 kB
from transformers import pipeline
import gradio
import base64
from PIL import Image
from io import BytesIO
from sentence_transformers import SentenceTransformer, util
# Hugging Face model identifiers used by this app, kept together so they are
# easy to find and swap.
_BACKGROUND_MODEL_ID = "facebook/maskformer-swin-large-coco"
_PERSON_MODEL_ID = "mattmdjaga/segformer_b2_clothes"
_SENTENCE_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# The models are loaded once at import time and shared by every request.
# Segments scene-level objects (COCO labels).
backgroundPipe = pipeline("image-segmentation", model=_BACKGROUND_MODEL_ID)
# Segments clothing / body parts.
PersonPipe = pipeline("image-segmentation", model=_PERSON_MODEL_ID)
# Embeds text so a user query can be matched to a segment label.
sentenceModal = SentenceTransformer(_SENTENCE_MODEL_ID)
def getImageDetails(image) -> tuple:
    """Segment *image* with both pipelines and collect one mask per label.

    Runs the clothes/body-part segmenter (``PersonPipe``) and the COCO
    segmenter (``backgroundPipe``) on the same image.

    Args:
        image: The input image, in any form the ``transformers``
            image-segmentation pipeline accepts (the Gradio UI passes a PIL
            image).

    Returns:
        A ``(masks, labels)`` tuple. ``masks`` maps each label to its mask
        image. ``labels`` lists the distinct labels in first-seen order:
        ``backgroundPipe`` labels first, then ``PersonPipe`` labels.
    """
    from PIL import ImageChops

    person = PersonPipe(image)
    bg = backgroundPipe(image)
    masks = {}
    for segment in [*bg, *person]:
        label, mask = segment["label"], segment["mask"]
        if label in masks:
            # A segmenter can emit several segments with the same label
            # (e.g. two people). Merge them rather than silently keeping only
            # the last one. The pipeline masks are binary, so a per-pixel
            # maximum is their union.
            masks[label] = ImageChops.lighter(masks[label], mask)
        else:
            masks[label] = mask
    # Dicts preserve insertion order, so this is the de-duplicated label list.
    return masks, list(masks)
def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* that best matches *sentence*.

    The query and every candidate are embedded with ``sentenceModal`` and
    scored by dot product. The highest-scoring candidate wins; ties go to
    the earliest candidate.

    Args:
        sentence: Free-text query typed by the user.
        semilist: Candidate strings (segmentation labels) to choose from.

    Returns:
        The best-matching element of *semilist*.

    Raises:
        ValueError: If *semilist* is empty.
    """
    if not semilist:
        raise ValueError("processSentence needs at least one candidate label")
    query_embedding = sentenceModal.encode(sentence)
    candidate_embeddings = sentenceModal.encode(semilist)
    # dot_score returns a (queries x candidates) score matrix; row 0 holds
    # the scores for our single query.
    scores = util.dot_score(query_embedding, candidate_embeddings)[0]
    # argmax returns the first maximal index, matching list.index(max(...)).
    return semilist[int(scores.argmax())]
def process_image(image):
    """Turn a black/white mask into a black-on-transparent overlay.

    Each pixel is transformed as follows:

    * pure black ``(0, 0, 0)`` becomes fully transparent ``(0, 0, 0, 0)``;
    * pure white ``(255, 255, 255)`` becomes black with its alpha kept;
    * every other pixel is left unchanged.

    The result is the same as swapping black and white and then making white
    transparent. It is computed in one vectorized pass instead of two
    per-pixel Python loops.

    Args:
        image: A PIL image; it is converted to RGBA first.

    Returns:
        A new RGBA PIL image of the same size.
    """
    import numpy as np

    pixels = np.array(image.convert("RGBA"))  # (height, width, 4) uint8
    rgb = pixels[..., :3]
    # Compute both masks before mutating so the tests see the original colors.
    is_black = (rgb == 0).all(axis=-1)
    is_white = (rgb == 255).all(axis=-1)
    pixels[is_black] = 0  # -> (0, 0, 0, 0): fully transparent
    pixels[is_white, :3] = 0  # -> black, alpha preserved
    # A (H, W, 4) uint8 array is interpreted by Pillow as RGBA.
    return Image.fromarray(pixels)
def processAndGetMask(image: Image.Image, text: str):
    """Return an overlay of the image region whose label best matches *text*.

    Args:
        image: The uploaded image. Gradio passes a PIL image because the
            input component uses ``type="pil"``.
        text: Free-text description of the region to select.

    Returns:
        The selected mask, rendered by ``process_image``.

    Raises:
        gradio.Error: If no image was uploaded, *text* is blank, or no
            segments were found. Gradio shows this message to the user
            instead of a stack trace.
    """
    if image is None:
        raise gradio.Error("Please upload an image.")
    if not text or not text.strip():
        raise gradio.Error("Please describe the region to select.")
    masks, labels = getImageDetails(image)
    if not labels:
        raise gradio.Error("No segments were found in the image.")
    selected_label = processSentence(text, labels)
    return process_image(masks[selected_label])
# Web UI: an uploaded image plus a free-text description go in; the mask of
# the best-matching segment comes out as a PIL image.
gr = gradio.Interface(
    fn=processAndGetMask,
    inputs=[gradio.Image(type="pil"), gradio.Text()],
    outputs=gradio.Image(type="pil"),
)
# Start the server (runs at import time, as the app entry point).
gr.launch()