# faceparser / app.py (Hugging Face Space by leonelhs)
# Commit 0479042: "make service" (2.58 kB)
import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
from bisnet import BiSeNet
from huggingface_hub import snapshot_download
from utils import vis_parsing_maps, decode_segmentation_masks, image_to_tensor
# Print the installed package versions to stdout (debugging aid for the Space logs).
os.system("pip freeze")

REPO_ID = "leonelhs/faceparser"
MODEL_NAME = "79999_iter.pth"

# BiSeNet segmentation network configured for 19 output classes.
model = BiSeNet(n_classes=19)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the Hub repository snapshot that contains the checkpoint file.
snapshot_folder = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(snapshot_folder, MODEL_NAME)
model.load_state_dict(torch.load(model_path, map_location=device))
# load_state_dict copies the tensors into the module's existing parameters,
# which were created on the CPU; map_location alone does not move the module.
# Move it explicitly so the weights live on ``device`` (on a GPU host the
# inference input is placed on CUDA, and a CPU model would then crash).
model.to(device)
model.eval()
def makeOverlay(image, mask):
    """Draw a segmentation class map on top of *image*.

    Args:
        image: PIL image; it is resized to 512x512 (bilinear) before drawing.
        mask: per-pixel class map, anything ``np.asarray`` accepts
            (presumably the array returned by ``makeMask``).

    Returns:
        Tuple ``(overlay, colormap)``: ``overlay`` is the blended image from
        ``vis_parsing_maps``; ``colormap`` is the class map it returns, colour
        decoded by ``decode_segmentation_masks``.
    """
    resized = image.resize((512, 512), Image.BILINEAR)
    class_map, blended = vis_parsing_maps(resized, np.asarray(mask))
    return blended, decode_segmentation_masks(class_map)
def makeMask(image):
    """Run the face-parsing model on *image* and return its per-pixel class map.

    Args:
        image: PIL image; it is resized to 512x512 (bilinear) before inference.

    Returns:
        A 2-D numpy array holding, for each pixel, the index of the
        highest-scoring class (argmax over the channel axis of the first model
        output). NOTE(review): the first output is assumed to be the main
        logits of shape (1, n_classes, H, W); confirm against BiSeNet.
    """
    resized = image.resize((512, 512), Image.BILINEAR)
    # Add a leading batch dimension of size 1.
    input_tensor = torch.unsqueeze(image_to_tensor(resized), 0)
    # Send the batch to wherever the model's weights actually are, rather than
    # guessing from torch.cuda.is_available(); this keeps inputs and weights on
    # the same device whether or not the model was moved to CUDA.
    model_device = next(model.parameters()).device
    input_tensor = input_tensor.to(model_device)
    with torch.no_grad():
        output = model(input_tensor)[0]
    return output.squeeze(0).cpu().numpy().argmax(0)
def predict(image):
    """Gradio entry point: segment the face in *image* and return the overlay.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        The overlay image produced by ``makeOverlay`` (its colour map is discarded).
    """
    # The mask must come from the model; calling predict() here would recurse
    # until RecursionError and the app would never return a result.
    mask = makeMask(image)
    overlay, _colormap = makeOverlay(image, mask)
    return overlay
# Text passed to gr.Interface below: page title, a markdown description and
# an HTML/markdown article.
title = "Face Parser"
description = r"""
## Image face parser for research
This is an implementation of <a href='https://github.com/zllrunning/face-parsing.PyTorch' target='_blank'>face-parsing.PyTorch</a>.
It has no particular purpose other than to start research on AI models.
"""
# NOTE(review): `[email protected]` below looks like an email-obfuscation
# placeholder left by copying from a web page; restore the real address.
article = r"""
Questions, doubts, comments, please email 📧 `[email protected]`
This demo is running on a CPU, if you like this project please make us a donation to run on a GPU or just give us a <a href='https://github.com/leonelhs/zeroscratches/' target='_blank'>Github ⭐</a>
<a href="https://www.buymeacoffee.com/leonelhs"><img src="https://img.buymeacoffee.com/button-api/?text=Buy me a coffee&emoji=&slug=leonelhs&button_colour=FFDD00&font_colour=000000&font_family=Cookie&outline_colour=000000&coffee_colour=ffffff" /></a>
<center><img src='https://visitor-badge.glitch.me/badge?page_id=zeroscratches.visitor-badge' alt='visitor badge'></center>
"""
# Web UI: a single PIL image in, the parsed overlay (numpy array) out.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=[gr.Image(type="numpy", label="Image face parsed")],
    title=title,
    description=description,
    article=article,
)
# Enable request queuing, then start serving the app.
demo.queue().launch()