Spaces:
Running
Running
File size: 2,452 Bytes
a505bd1 3145448 a505bd1 3145448 a505bd1 2b13f4e a505bd1 4e96c9f a505bd1 2b13f4e a505bd1 2b13f4e a505bd1 3145448 2b13f4e a505bd1 3145448 a505bd1 2b13f4e 3145448 a505bd1 3145448 a505bd1 2b13f4e 3145448 a505bd1 2b13f4e 3145448 543442b 3145448 543442b a505bd1 3145448 a505bd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import tempfile

import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# Load the BiRefNet segmentation model and place it on the CPU (this app
# runs inference on CPU rather than GPU).
# NOTE: trust_remote_code=True lets transformers execute the model code that
# ships with the "ZhengPeng7/BiRefNet" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cpu")
# Image preprocessing: resize to 1024x1024, convert to a tensor, then
# normalize each channel with the mean / std values below (they match the
# commonly used ImageNet statistics).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def process(image):
    """Return a copy of *image* whose alpha channel is the predicted mask.

    The image is run through ``transform_image`` and the ``birefnet`` model;
    the last element of the model output is passed through a sigmoid, turned
    into a PIL image, resized to the input size and used as the alpha band.

    Args:
        image: PIL image to segment. Callers in this file pass RGB images.

    Returns:
        A new PIL image with the mask applied as its alpha channel. The
        input image itself is left unmodified.
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cpu")

    # Run the prediction without tracking gradients.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    # putalpha() modifies the image in place, so apply it to a copy to avoid
    # silently changing the caller's image.
    result = image.copy()
    result.putalpha(mask)
    return result
def fn(image):
    """Gradio handler: remove the background and build a JPG download.

    Args:
        image: Whatever the ``gr.Image`` input passes in; it is converted to
            an RGB PIL image with ``load_img``.

    Returns:
        A 2-tuple ``(slider_images, jpg_path)``. ``slider_images`` is the
        pair ``(processed, original)`` for the ImageSlider component, and
        ``jpg_path`` is the path of a JPEG file containing the
        background-removed image flattened onto white.
    """
    im = load_img(image, output_type="pil").convert("RGB")
    # Keep an untouched copy for the "before" side of the slider;
    # process() may add the alpha channel to the image it is given.
    origin = im.copy()
    processed_image = process(im)

    # JPEG cannot store an alpha channel, so flatten the cut-out onto a
    # white background instead of saving the unprocessed original.
    jpg_image = Image.new("RGB", processed_image.size, (255, 255, 255))
    jpg_image.paste(processed_image, mask=processed_image.getchannel("A"))

    # Write into a fresh temporary directory per call so concurrent requests
    # do not overwrite each other's "output.jpg".
    jpg_path = os.path.join(tempfile.mkdtemp(), "output.jpg")
    jpg_image.save(jpg_path, format="JPEG")

    # ImageSlider compares two images, so it needs a pair, not a single item.
    return (processed_image, origin), jpg_path
def process_file(f):
    """Remove the background from the image file *f* and save it as a PNG.

    Args:
        f: Path of the input image file.

    Returns:
        Path of the written PNG: the input path with its extension replaced
        by ``.png``. If *f* already ends in ``.png`` this is the same path,
        so the input file is overwritten.
    """
    # splitext() only strips an extension from the final path component, so a
    # dot in a directory name or a file without an extension is handled
    # correctly (a plain rsplit(".") would cut the path at the wrong place).
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Gradio component definitions: the before/after slider output, the image
# upload input, and the file output used for the JPG download.
slider1 = ImageSlider(label="Processed Image", type="pil")
image_upload = gr.Image(label="Upload an image")
output_download = gr.File(label="Download JPG File")
# Sample images offered as examples (must be located in the same folder as app.py).
sample_images = ["1.png", "2.jpg", "3.png"]
# Gradio interface setup.
# gr.Interface requires fn, inputs and outputs, so it cannot be used to wrap
# another Interface just to add a title; the title and description are set
# directly on the single Interface instead, and `demo` refers to it.
# The description is Korean for: "Upload an image to view the
# background-removed result and download it as a JPG file."
tab1 = gr.Interface(
    fn=fn,
    inputs=image_upload,
    outputs=[slider1, output_download],
    examples=sample_images,
    api_name="image",
    title="Background Removal Tool",
    description="이미지를 업로드하면 배경이 제거된 이미지를 확인하고 JPG 파일로 다운로드할 수 있습니다.",
)
demo = tab1
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)
|