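"""Gradio demo for the Open Remove Background Model (ormbg).

Loads the pretrained ormbg checkpoint and serves a simple web UI that removes
the background from an uploaded image, returning an RGBA cutout.
"""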
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from ormbg import ORMBG
from PIL import Image


model_path = "ormbg.pth"

# Load the pretrained weights onto the GPU if one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = ORMBG()
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()


def resize_image(image):
    # The model expects a fixed 1024x1024 RGB input.
    image = image.convert("RGB")
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


def inference(image):
    # Gradio passes the upload as an H x W x 3 uint8 numpy array.
    orig_image = Image.fromarray(image)
    w, h = orig_image.size

    # Prepare the model input: resize, then convert to an NCHW float tensor in [0, 1].
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = im_tensor.to(device)

    # Inference (no gradients needed).
    with torch.no_grad():
        result = net(im_tensor)

    # Post-process: resize the predicted mask back to the original resolution
    # and min-max normalize it to [0, 1].
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
    )
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # Convert the mask to a PIL image.
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # Use the mask as an alpha channel: paste the original image onto a fully
    # transparent canvas so the background becomes transparent.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im
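
# Standalone usage sketch (not executed by the Gradio app): assuming an
# "input.png" sits next to this script, inference() takes an H x W x 3
# uint8 numpy array and returns an RGBA PIL image.
#
#     img = np.array(Image.open("input.png").convert("RGB"))
#     cutout = inference(img)
#     cutout.save("output.png")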


# Note: these standalone components are created outside a gr.Blocks context,
# so they are not rendered; the visible title and description come from the
# gr.Interface below.
gr.Markdown("## Open Remove Background Model (ormbg)")
gr.HTML(
    """
  <p style="margin-bottom: 10px; font-size: 94%">
    This is a demo of the Open Remove Background Model (ormbg), which uses the
    <a href="https://huggingface.co/schirrmacher/ormbg" target="_blank">ormbg model</a> as its backbone.
  </p>
"""
)
title = "Open Remove Background Model (ormbg)"
description = r"""
This model is a <strong>fully open-source background remover</strong> optimized for images with humans.

It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).

This is the first iteration of the model, so there will be improvements!
If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!

- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, and tutorials
- <a href='https://huggingface.co/datasets/schirrmacher/humans' target='_blank'>Dataset</a>: see training images, segmentation data, and backgrounds
- <a href='https://huggingface.co/schirrmacher/ormbg#research' target='_blank'>Research</a>: see the current approach and planned improvements

"""
examples = ["./example1.png", "./example2.png", "./example3.png"]

demo = gr.Interface(
    fn=inference,
    inputs="image",
    outputs="image",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(share=False)
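
# Note: share=False serves the app locally only (Gradio's default is
# http://127.0.0.1:7860). Passing share=True to demo.launch() instead
# creates a temporary public Gradio link.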