Spaces:
Runtime error
Runtime error
File size: 3,725 Bytes
542c815 3f8e328 542c815 9bf688a 542c815 d6e753e 8a357d1 542c815 a0cbc48 eb330d2 a0cbc48 56b3741 a0cbc48 4f91b95 8f942dd 542c815 a0cbc48 542c815 988f91c 542c815 56b3741 a0cbc48 542c815 70974c3 542c815 70974c3 542c815 949892a 542c815 d909bca 56b3741 d909bca 56b3741 acf2248 56b3741 d909bca 56b3741 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
from io import BytesIO
import base64
import re
import os
# Shared secret that API callers must present; taken from the environment,
# falling back to a placeholder value when the variable is not set.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Matches the "data:image/<subtype>;base64," header of a data URI for the
# png, jpeg, jpg and webp subtypes.
data_uri_pattern = re.compile(
    r'data:image/(png|jpeg|jpg|webp);base64,'
)
def readb64(b64):
    """Decode a base64 string into a PIL image.

    Any data URI header recognised by ``data_uri_pattern`` is removed before
    the remaining text is base64-decoded and handed to ``Image.open``.
    """
    payload = data_uri_pattern.sub("", b64)
    raw_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(raw_bytes))
# convert from PIL to base64
def writeb64(image):
    """Serialize *image* as PNG and return the bytes as a base64 text string.

    *image* only needs a ``save(fp, format=...)`` method, as PIL images have.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
# Build the segmentation network and fetch its weights from the Hugging Face
# Hub.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')

if not torch.cuda.is_available():
    # No GPU: remap the checkpoint tensors onto the CPU while loading.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()

# Switch to evaluation mode on either device.
net.eval()
def resize_image(image):
    """Return *image* converted to RGB and resampled to 1024x1024.

    Bilinear resampling is used; the aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)
def process(secret_token, base64_in):
    """Remove the background of a base64-encoded image.

    Args:
        secret_token: Token supplied by the caller; it must equal the
            module-level ``SECRET_TOKEN``.
        base64_in: Base64 image payload, decoded with ``readb64`` (a data URI
            header is accepted).

    Returns:
        A base64 string of a PNG: the original image pasted onto a fully
        transparent RGBA canvas using the network output as the paste mask.

    Raises:
        gr.Error: If ``secret_token`` does not match ``SECRET_TOKEN``.
    """
    import hmac  # local import: only needed for the constant-time comparison

    # Compare as bytes in constant time: this avoids leaking how many leading
    # characters matched and avoids the TypeError compare_digest raises for
    # non-ASCII str arguments. Non-string tokens (e.g. None) are rejected.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8"))
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    orig_image = readb64(base64_in)

    # prepare input
    w, h = orig_image.size  # PIL reports (width, height)
    image = resize_image(orig_image)
    im_np = np.array(image)
    # HWC -> CHW, then add the batch dimension.
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    # Subtract 0.5 per channel (std of 1.0 leaves the scale untouched).
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference only: disabling autograd avoids building a graph and keeping
    # activations alive for every request.
    with torch.no_grad():
        result = net(im_tensor)

        # post process: resize the prediction back to the input resolution
        # (interpolate takes (height, width)) and min-max scale it.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        # Clamp the range so a constant prediction yields an all-zero mask
        # instead of dividing by zero and casting NaNs to uint8.
        result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

    # image to pil
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return writeb64(new_im)
# Gradio front end. The visible page is covered by a fixed, full-viewport HTML
# notice; the form below it exists so that `process` is reachable through the
# API under the name "run".
with gr.Blocks() as demo:
    secret_token = gr.Text(
        label='Secret Token',
        max_lines=1,
        placeholder='Enter your secret token')
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically remove the background of an image.</p>
<p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")
    base64_in = gr.Textbox(label="Base64 Input")
    base64_out = gr.Textbox(label="Base64 Output")
    submit_btn = gr.Button("Submit")
    # Wire the button to `process` and name the endpoint "run".
    submit_btn.click(
        fn=process,
        inputs=[secret_token, base64_in],
        outputs=base64_out,
        api_name="run")

# Cap the request queue at 20 pending jobs, then start the server.
demo.queue(max_size=20).launch()