File size: 2,648 Bytes
8ce07a1
 
 
 
 
 
 
 
b09a145
a505bd1
 
 
 
 
 
 
 
b09a145
797deb4
b09a145
 
 
 
 
797deb4
 
 
b09a145
797deb4
 
b09a145
2b13f4e
b09a145
 
2b13f4e
b09a145
2b13f4e
 
 
 
 
 
b09a145
797deb4
b09a145
 
 
 
 
797deb4
 
a505bd1
b09a145
797deb4
 
a505bd1
3145448
797deb4
a505bd1
797deb4
 
 
a505bd1
b09a145
797deb4
 
a505bd1
797deb4
 
 
a505bd1
797deb4
 
a505bd1
 
 
797deb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Model loading is done inside the GPU-decorated functions so that the model is loaded once a GPU has been allocated.
# Preprocessing applied to every input image before inference: resize to a
# fixed 1024x1024 resolution, convert to a tensor, then normalize each channel
# with the standard ImageNet mean / standard deviation.
transform_image = transforms.Compose([
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

@spaces.GPU
def fn(image):
    """Remove the background from ``image`` and return a before/after pair.

    Args:
        image: Anything ``load_img`` accepts. The Gradio tabs in this file
            pass either an uploaded image or a URL string.

    Returns:
        A ``(processed, original)`` tuple of PIL images: the cut-out with an
        alpha channel first, followed by an untouched RGB copy of the input.
    """
    # The model is loaded inside the decorated function (not at import time)
    # and moved to the GPU on every call.
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    segmenter.to("cuda")

    source = load_img(image, output_type="pil").convert("RGB")
    # process() writes the alpha channel into its argument in place, so keep
    # a pristine copy for the "original" half of the result.
    original = source.copy()
    cutout = process(source, segmenter)
    return (cutout, original)

def process(image, model):
    """Run ``model`` on ``image`` and attach the predicted mask as alpha.

    Args:
        image: PIL image to segment. It is modified in place: the predicted
            mask becomes its alpha channel.
        model: Segmentation model, called on a CUDA batch. The last element
            of its output is used as the mask logits.

    Returns:
        The same ``image`` object, now carrying the alpha channel.
    """
    # Preprocess, add a batch dimension, and move the batch to the GPU.
    batch = transform_image(image)
    batch = batch.unsqueeze(0)
    batch = batch.to("cuda")

    # Prediction: no gradients are needed for inference.
    with torch.no_grad():
        outputs = model(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # First (only) item of the batch, with singleton dimensions dropped.
    matte = probabilities[0].squeeze()
    # Convert to a PIL image and scale back to the input's original size.
    alpha = transforms.ToPILImage()(matte).resize(image.size)
    image.putalpha(alpha)
    return image

@spaces.GPU
def process_file(f):
    """Remove the background from an image file and save the result as PNG.

    Args:
        f: Path of the input image as a string.

    Returns:
        Path of the written PNG: the input path with its file extension
        replaced by ``.png``.
    """
    import os

    # The model is loaded inside the decorated function (not at import time)
    # and moved to the GPU on every call.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    birefnet.to("cuda")

    # splitext only inspects the final path component, so a dot in a parent
    # directory name (e.g. "/tmp/run.1/image") can no longer truncate the
    # output path the way rsplit(".", 1) did.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im, birefnet)
    transparent.save(name_path)
    return name_path

# --- Output / input components ---------------------------------------------
# Sliders show the (processed, original) pair returned by fn.
slider1 = ImageSlider(type="pil", label="Processed Image")
slider2 = ImageSlider(type="pil", label="Processed Image from URL")
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(type="filepath", label="Upload an image")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# --- Example inputs ---------------------------------------------------------
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# --- One interface per tab --------------------------------------------------
tab1 = gr.Interface(
    fn=fn,
    inputs=image_upload,
    outputs=slider1,
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn=fn,
    inputs=url_input,
    outputs=slider2,
    examples=[url_example],
    api_name="text",
)
tab3 = gr.Interface(
    fn=process_file,
    inputs=image_file_upload,
    outputs=output_file,
    examples=["butterfly.jpg"],
    api_name="png",
)

demo = gr.TabbedInterface(
    interface_list=[tab1, tab2, tab3],
    tab_names=["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)