File size: 2,908 Bytes
80fd191 665e653 80fd191 665e653 80fd191 cc6c61e 80fd191 665e653 80fd191 665e653 80fd191 665e653 c57634b 665e653 80fd191 665e653 80fd191 665e653 80fd191 665e653 80fd191 cc6c61e 80fd191 e19fd5f 4d6ff3b 80fd191 51557c9 80fd191 4d6ff3b 80fd191 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from ormbg import ORMBG
from PIL import Image
# Checkpoint file holding the ORMBG weights.
model_path = "ormbg.pth"

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = ORMBG()
net.to(device)

# `map_location=device` remaps every stored tensor onto the chosen device, so a
# single code path serves both CPU-only and CUDA hosts, regardless of the
# device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
net.load_state_dict(torch.load(model_path, map_location=device))

# Put the module in evaluation mode; this script only runs inference.
net.eval()
def resize_image(image):
    """Return ``image`` converted to RGB and rescaled to the model input size.

    Args:
        image: A PIL image (anything offering ``convert`` and ``resize``).

    Returns:
        A new 1024x1024 RGB image produced with bilinear resampling. The
        aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert("RGB")
    return rgb_image.resize(target_size, Image.BILINEAR)
def inference(image):
    """Cut the background out of ``image`` and return the result as RGBA.

    The input is resized with ``resize_image``, scaled to ``[0, 1]`` and run
    through the module-level ``net``. The predicted mask is resized back to
    the original resolution, min-max normalised, and used as the paste mask
    when copying the original image onto a fully transparent canvas.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``; this function is
            wired to the Gradio ``"image"`` input.

    Returns:
        A ``PIL.Image.Image`` in ``RGBA`` mode with the same width and height
        as the input.
    """
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    im_np = np.array(resize_image(orig_image))
    # HWC uint8 -> 1xCxHxW float32 scaled to [0, 1]
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = im_tensor.unsqueeze(0) / 255.0
    # Follow the device the model actually lives on instead of re-probing CUDA.
    im_tensor = im_tensor.to(next(net.parameters()).device)

    # inference + post process; no_grad keeps autograd from recording a graph
    # for every request, which would only waste memory here.
    with torch.no_grad():
        result = net(im_tensor)
        # Bilinear interpolation with a 2-tuple size needs a 4-D tensor, so
        # index batch 0 / channel 0 explicitly. A blanket squeeze would also
        # drop a height or width of 1.
        # NOTE(review): assumes result[0][0] is a single-image, single-channel
        # mask (N == C == 1) -- confirm against ORMBG.forward.
        mask = F.interpolate(result[0][0], size=(h, w), mode="bilinear")[0, 0]

        lo = torch.min(mask)
        hi = torch.max(mask)
        span = hi - lo
        if span > 0:
            mask = (mask - lo) / span
        else:
            # Constant prediction: min-max scaling would be 0/0 (NaN), so just
            # clamp the raw values into the valid mask range instead.
            mask = mask.clamp(0.0, 1.0)

    # image to pil
    im_array = (mask * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(im_array)

    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
# Intro heading and blurb for the demo page.
# NOTE(review): both components are instantiated at module level, outside any
# `gr.Blocks()` context visible in this file, and their return values are
# discarded. They are presumably never rendered by the `gr.Interface` defined
# further down -- confirm, and if so move this text into the Interface's
# `description` / `article` or wrap the UI in a Blocks layout.
gr.Markdown("## Open Remove Background Model (ormbg)")
gr.HTML(
"""
<p style="margin-bottom: 10px; font-size: 94%">
This is a demo for Open Remove Background Model (ormbg) that using
<a href="https://huggingface.co/schirrmacher/ormbg" target="_blank">Open Remove Background Model (ormbg) model</a> as backbone.
</p>
"""
)
# Example image paths handed to the interface (relative to the working dir).
examples = ["./example1.png", "./example2.png", "./example3.png"]

# Heading passed to the interface.
title = "Open Remove Background Model (ormbg)"

# Longer blurb passed to the interface; contains inline HTML links.
description = r"""
This model is a <strong>fully open-source background remover</strong> optimized for images with humans.
If you identify cases were the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>please contact me</a>!
- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card: inference code</a>
- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset: all images and backgrounds</a>
Known issues (work in progress):
- close-ups: from above, from below, profile, from side
- minor issues with hair segmentation when hair creates loops
- more various backgrounds needed
"""

# Single image in, single image out, backed by `inference`.
demo = gr.Interface(
    title=title,
    description=description,
    fn=inference,
    inputs="image",
    outputs="image",
    examples=examples,
)

# Only start the server when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=False)
|