File size: 2,667 Bytes
5de8b22
a3f48ee
 
 
 
 
 
 
 
a903ae4
 
 
 
 
 
a3f48ee
 
 
9281027
a3f48ee
a903ae4
a3f48ee
 
 
 
 
 
 
 
c2df784
 
 
 
a3f48ee
 
 
 
80aa27f
c2df784
 
5068bdc
a3f48ee
 
 
 
a903ae4
a3f48ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0eeab4
 
a3f48ee
 
 
 
 
 
a973e8e
a3f48ee
 
a973e8e
a3f48ee
4c550a6
a3f48ee
 
4c550a6
ce8130a
a3f48ee
 
 
c2df784
a3f48ee
 
 
a903ae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "high" lets float32 matmuls use faster reduced-precision kernels where the
# backend supports them (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision("high")

# trust_remote_code=True allows transformers to execute the model code that
# is hosted in the "briaai/RMBG-2.0" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
birefnet.to(device)

# Preprocessing applied to every input image before it is fed to the model:
# resize to 1024x1024, convert to a tensor, then normalise each channel with
# the given mean / std (the commonly used ImageNet statistics).
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

# Directory that receives the PNG files written by fn().
output_folder = 'output_images'
# exist_ok=True makes the creation idempotent and avoids the race between a
# separate "does it exist?" check and the makedirs call.
os.makedirs(output_folder, exist_ok=True)

def fn(image):
    """Remove the background of ``image`` and save the result as a PNG.

    Args:
        image: Any input accepted by ``load_img``. In this file it is wired
            to a ``gr.Image`` component and to a ``gr.Textbox`` holding an
            image URL.

    Returns:
        A tuple ``((result, origin), image_path)`` where ``result`` is the
        image returned by ``process`` (with an alpha channel), ``origin`` is
        an untouched RGB copy of the input, and ``image_path`` is the path of
        the PNG written inside ``output_folder``.
    """
    # Function-scope import keeps this fix self-contained.
    import uuid

    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    # process() adds the alpha channel to the image it receives, so keep a
    # pristine copy for the before/after slider.
    origin = im.copy()
    result = process(im)

    # A unique file name per call: with a single fixed name, overlapping
    # requests would overwrite each other's output before it is delivered.
    file_name = f"no_bg_image_{uuid.uuid4().hex}.png"
    image_path = os.path.join(output_folder, file_name)
    result.save(image_path)
    return (result, origin), image_path

@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as alpha.

    The image is preprocessed with ``transform_image``, run through
    ``birefnet`` on ``device``, and the last model output is turned into a
    mask that is resized back to the input size.

    Args:
        image: Input image; it must provide ``.size`` and ``.putalpha``
            (callers in this file pass an RGB PIL image).

    Returns:
        The same ``image`` object, modified in place by ``putalpha``.
    """
    original_size = image.size

    # Add a leading batch dimension and move the tensor to the model device.
    batch = transform_image(image).unsqueeze(0)
    batch = batch.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # First (only) item of the batch, converted to a PIL mask.
    to_pil = transforms.ToPILImage()
    alpha = to_pil(probabilities[0].squeeze())

    # Scale the mask back to the input resolution and use it as alpha.
    image.putalpha(alpha.resize(original_size))
    return image
  
def process_file(f):
    """Remove the background of the image file ``f`` and save it as PNG.

    The output is written next to the input, using the input path with its
    extension replaced by ``.png``.

    Args:
        f: Path of the input image (wired to a ``gr.Image`` with
            ``type="filepath"`` in this file).

    Returns:
        The path of the PNG file that was written.
    """
    # os.path.splitext only strips an extension from the final path
    # component. Splitting on the last "." of the whole string would cut
    # inside a directory name when the file itself has no extension
    # (e.g. "/tmp/a.b/img" -> "/tmp/a.png").
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Before/after sliders, one per tab that uses fn().
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")

# Input / output components shared by the interfaces below.
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")

# Example inputs shown in the UI: a local image and a remote URL.
chameleon = load_img("giraffe.jpg", output_type="pil")
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

# Tab 1: upload an image.
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[slider1, gr.File(label="output png file")],
    examples=[chameleon],
    api_name="image",
)

# Tab 2: paste an image URL.
tab2 = gr.Interface(
    fn=fn,
    inputs=text,
    outputs=[slider2, gr.File(label="output png file")],
    examples=[url],
    api_name="text",
)

# NOTE(review): tab3 is constructed but is not part of the list passed to
# gr.TabbedInterface below -- confirm whether it should be exposed.
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["giraffe.jpg"],
    api_name="png",
)

demo = gr.TabbedInterface(
    [tab1, tab2],
    ["input image", "input url"],
    title="RMBG-2.0 for background removal",
)

if __name__ == "__main__":
    demo.launch(share=True, show_error=True)