File size: 2,452 Bytes
a505bd1
 
 
3145448
a505bd1
 
 
3145448
 
a505bd1
2b13f4e
a505bd1
 
 
4e96c9f
a505bd1
2b13f4e
a505bd1
 
 
 
 
 
 
 
2b13f4e
 
 
 
 
 
 
 
 
 
 
 
a505bd1
 
 
 
 
3145448
 
 
 
 
 
 
2b13f4e
a505bd1
3145448
 
 
 
 
 
 
a505bd1
2b13f4e
3145448
a505bd1
3145448
a505bd1
2b13f4e
3145448
a505bd1
2b13f4e
3145448
543442b
 
3145448
543442b
 
 
a505bd1
3145448
 
 
 
a505bd1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os

# ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ CPU๋กœ ์„ค์ •
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cpu")  # GPU -> CPU๋กœ ๋ณ€๊ฒฝ

# Model input preprocessing: resize to a fixed 1024x1024, convert the PIL
# image to a float tensor, then normalize each channel with the mean/std below.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def process(image):
    """Attach a BiRefNet foreground mask to ``image`` as its alpha channel.

    ``image`` is modified in place via ``putalpha`` and also returned, so a
    caller that still needs the untouched picture must copy it first.
    The callers in this file pass an RGB PIL image.

    Args:
        image: PIL image to segment.

    Returns:
        The same image object, now carrying the predicted mask as alpha.
    """
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cpu")  # add batch dim

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        raw = birefnet(batch)[-1]        # last output of the model
        probs = raw.sigmoid().cpu()      # squashed into (0, 1)

    # Single-channel mask, scaled back to the input resolution.
    mask = transforms.ToPILImage()(probs[0].squeeze()).resize(original_size)
    image.putalpha(mask)
    return image

def fn(image):
    """Remove the background from an uploaded image (Gradio callback).

    Args:
        image: Value from the upload component or an example path; decoded
            to PIL by ``load_img``.

    Returns:
        A ``(processed, original)`` image pair for the ImageSlider, and the
        path of a JPG export of the background-removed image. JPEG has no
        alpha channel, so transparent areas are flattened onto white.

    Raises:
        gr.Error: if no image was supplied.
    """
    import tempfile

    if image is None:
        raise gr.Error("Please upload an image.")

    im = load_img(image, output_type="pil").convert("RGB")
    # `process` adds the alpha channel to `im` in place, so keep a copy of
    # the untouched input for the before/after slider.
    origin = im.copy()
    processed_image = process(im)

    # A plain RGBA -> RGB conversion would only drop the alpha channel and
    # give back the original picture; composite the cut-out onto white so the
    # JPG really shows the background removed.
    jpg_image = Image.new("RGB", processed_image.size, (255, 255, 255))
    jpg_image.paste(processed_image, mask=processed_image.getchannel("A"))

    # One file per request: a fixed name would let concurrent users
    # overwrite (and download) each other's results.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        jpg_path = tmp.name
    jpg_image.save(jpg_path, format="JPEG")

    # ImageSlider displays a pair of images (result vs. original).
    return (processed_image, origin), jpg_path

def process_file(f):
    """Remove the background from the image at path ``f`` and save it as PNG.

    The PNG (which keeps the alpha channel produced by ``process``) is written
    next to the input, with the file extension replaced by ``.png``.

    Args:
        f: Path of the input image (``str`` or path-like).

    Returns:
        The path of the written PNG file, as a string.
    """
    # splitext strips only a real extension; rsplit(".", 1) would cut into a
    # directory name containing a dot (e.g. "./data.v2/photo" -> "./data.png").
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Gradio ์ปดํฌ๋„ŒํŠธ ์ •์˜
slider1 = ImageSlider(label="Processed Image", type="pil")
image_upload = gr.Image(label="Upload an image")
output_download = gr.File(label="Download JPG File")

# ์ƒˆ๋กœ์šด ์ƒ˜ํ”Œ ์ด๋ฏธ์ง€ ์ถ”๊ฐ€ (app.py์™€ ๋™์ผํ•œ ํด๋”์— ์œ„์น˜ํ•ด์•ผ ํ•จ)
sample_images = ["1.png", "2.jpg", "3.png"]

# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
tab1 = gr.Interface(
    fn=fn, 
    inputs=image_upload, 
    outputs=[slider1, output_download], 
    examples=sample_images, 
    api_name="image"
)

demo = gr.Interface(
    tab1, 
    title="Background Removal Tool", 
    description="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋ฐฐ๊ฒฝ์ด ์ œ๊ฑฐ๋œ ์ด๋ฏธ์ง€๋ฅผ ํ™•์ธํ•˜๊ณ  JPG ํŒŒ์ผ๋กœ ๋‹ค์šด๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
)

def main():
    """Start the Gradio server, showing server-side errors in the UI."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()