File size: 3,725 Bytes
542c815
3f8e328
542c815
 
9bf688a
542c815
d6e753e
 
 
8a357d1
542c815
a0cbc48
 
 
eb330d2
a0cbc48
56b3741
 
a0cbc48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f91b95
8f942dd
542c815
 
 
 
 
 
 
a0cbc48
542c815
 
988f91c
 
542c815
 
 
56b3741
 
 
 
 
 
a0cbc48
542c815
 
 
 
 
 
 
 
 
 
70974c3
542c815
 
 
 
 
 
 
70974c3
542c815
 
 
949892a
542c815
d909bca
56b3741
 
 
 
d909bca
56b3741
 
 
 
 
 
 
 
 
 
 
 
 
 
acf2248
 
56b3741
 
 
 
d909bca
56b3741
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

from io import BytesIO
import base64
import re
import os

# Token callers must present to `process`; read once at import time from the
# SECRET_TOKEN environment variable, falling back to 'default_secret' if unset.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches a "data:image/<type>;base64," data-URI header for png, jpeg, jpg
# and webp payloads.
data_uri_pattern = re.compile(
    r'data:image/(png|jpeg|jpg|webp);base64,'
)

def readb64(b64):
    """Decode a base64 string into a PIL image.

    Any data-URI header matched by ``data_uri_pattern`` is removed before the
    remaining text is base64-decoded and handed to ``Image.open``.
    """
    payload = data_uri_pattern.sub("", b64)
    raw_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(raw_bytes))
    
# convert from PIL to base64
def writeb64(image):
    """Save ``image`` as PNG in memory and return the bytes as base64 text."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")

# Build the network and load the pretrained weights once, at import time.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
# Always deserialize onto the CPU first and move the model afterwards: this
# works regardless of which device the checkpoint tensors were saved from and
# removes the duplicated load call in the CUDA / CPU branches.
# NOTE(review): torch.load unpickles the file; the checkpoint is downloaded
# from the Hub, so consider passing weights_only=True if the installed torch
# version supports it.
net.load_state_dict(torch.load(model_path, map_location="cpu"))
if torch.cuda.is_available():
    net = net.cuda()
# Inference only: switch the module to evaluation mode.
net.eval()


def resize_image(image):
    """Return ``image`` converted to RGB and resized to 1024x1024 (bilinear)."""
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)


def process(secret_token, base64_in):
    """Remove the background of a base64-encoded image.

    Args:
        secret_token: Token supplied by the caller; must equal SECRET_TOKEN.
        base64_in: Base64 image data, optionally carrying a data-URI header
            (decoded by ``readb64``).

    Returns:
        The base64 text of an RGBA PNG, as produced by ``writeb64``: the
        original image pasted onto a fully transparent canvas using the
        predicted mask.

    Raises:
        gr.Error: If ``secret_token`` does not match SECRET_TOKEN.
    """
    import hmac  # local import: only needed for the constant-time comparison

    # Compare as bytes in constant time so the check does not leak how many
    # leading characters of the token were correct. "surrogatepass" keeps
    # odd inputs (lone surrogates) from raising instead of being rejected.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8", "surrogatepass"),
        SECRET_TOKEN.encode("utf-8", "surrogatepass"))
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    orig_image = readb64(base64_in)

    # prepare input
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    # HWC uint8 -> 1xCxHxW float in [0, 1], then shifted by the 0.5 mean
    # (std of 1.0 leaves the scale untouched).
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference + post process; no autograd graph is needed here, so skip
    # gradient tracking to save memory.
    with torch.no_grad():
        result = net(im_tensor)
        # NOTE(review): assumes result[0][0] is a 4-D (N, C, H, W) tensor, as
        # required by bilinear interpolation -- confirm against BriaRMBG.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        if ma > mi:
            # Min-max stretch the mask to the full [0, 1] range.
            result = (result - mi) / (ma - mi)
        else:
            # Constant (or NaN) prediction: the stretch would divide by zero,
            # so fall back to an all-zero (fully transparent) mask.
            result = torch.zeros_like(result)

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the original image onto a transparent canvas through the mask
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return writeb64(new_im)


# Gradio front end: a minimal form whose submit button calls `process`.
# Components are created in display order inside the Blocks context.
with gr.Blocks() as demo:
    # Single-line text field for the token that `process` checks.
    secret_token = gr.Text(
        label='Secret Token',
        max_lines=1,
        placeholder='Enter your secret token')
    # Fixed-position, full-window white panel (z-index 100) carrying a notice
    # that points visitors to the original space.
    gr.HTML("""
        <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
        <div style="text-align: center; color: black;">
        <p style="color: black;">This space is a REST API to programmatically remove the background of an image.</p>
        <p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4" target="_blank">original space</a>, thank you!</p>
        </div>
        </div>""")
    # Text boxes carrying the base64 request and response payloads.
    base64_in = gr.Textbox(label="Base64 Input")
    base64_out = gr.Textbox(label="Base64 Output")
    submit_btn = gr.Button("Submit")
    # Wire the button to `process`; the event is registered under the
    # API name "run".
    submit_btn.click(
        fn=process,
        inputs=[secret_token, base64_in],
        outputs=base64_out,
        api_name="run")

# Enable the request queue (max_size=20) and start the app at import time.
demo.queue(max_size=20).launch()