jbilcke-hf's picture
jbilcke-hf HF Staff
Update app.py
b76ec0d verified
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
from io import BytesIO
import base64
import re
import os
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
# Regex pattern to match data URI scheme
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
def readb64(b64):
# Remove any data URI scheme prefix with regex
b64 = data_uri_pattern.sub("", b64)
# Decode and open the image with PIL
img = Image.open(BytesIO(base64.b64decode(b64)))
return img
# convert from PIL to base64
def writeb64(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
b64image = base64.b64encode(buffered.getvalue())
b64image_str = b64image.decode("utf-8")
return b64image_str
net=BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net=net.cuda()
else:
net.load_state_dict(torch.load(model_path,map_location="cpu"))
net.eval()
def resize_image(image):
image = image.convert('RGB')
model_input_size = (1024, 1024)
image = image.resize(model_input_size, Image.BILINEAR)
return image
def process(secret_token, base64_in):
if secret_token != SECRET_TOKEN:
raise gr.Error(
f'Invalid secret token. Please fork the original space if you want to use it for yourself.')
orig_image = readb64(base64_in)
# prepare input
w,h = orig_im_size = orig_image.size
image = resize_image(orig_image)
im_np = np.array(image)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
im_tensor = torch.unsqueeze(im_tensor,0)
im_tensor = torch.divide(im_tensor,255.0)
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
if torch.cuda.is_available():
im_tensor=im_tensor.cuda()
#inference
result=net(im_tensor)
# post process
result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
ma = torch.max(result)
mi = torch.min(result)
result = (result-mi)/(ma-mi)
# image to pil
im_array = (result*255).cpu().data.numpy().astype(np.uint8)
pil_im = Image.fromarray(np.squeeze(im_array))
# paste the mask on the original image
new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
new_im.paste(orig_image, mask=pil_im)
base64_out = writeb64(new_im)
return base64_out
with gr.Blocks() as demo:
secret_token = gr.Text(
label='Secret Token',
max_lines=1,
placeholder='Enter your secret token')
gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically remove the background of an image.</p>
<p style="color: black;">Interested in using it? Please use the <a href="https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")
base64_in = gr.Textbox(label="Base64 Input")
base64_out = gr.Textbox(label="Base64 Output")
submit_btn = gr.Button("Submit")
submit_btn.click(
fn=process,
inputs=[secret_token, base64_in],
outputs=base64_out,
api_name="run")
demo.queue(max_size=20).launch()