Kims12's picture
Update app.py
2b13f4e verified
raw
history blame
2.45 kB
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os
# ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ CPU๋กœ ์„ค์ •
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cpu") # GPU -> CPU๋กœ ๋ณ€๊ฒฝ
# Image preprocessing for the model: resize to 1024x1024, convert the PIL
# image to a float tensor, then normalize each channel with the standard
# ImageNet mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def process(image):
    """Cut out the foreground of *image* with BiRefNet.

    The model output is turned into a grayscale mask, resized back to the
    input's dimensions and attached to *image* as its alpha channel.

    Args:
        image: A PIL image. The callers in this file pass an RGB image.
            It is modified in place: ``putalpha`` adds or replaces its
            alpha band.

    Returns:
        The same PIL image object, now carrying the predicted alpha mask.
    """
    width_height = image.size
    batch = transform_image(image).unsqueeze(0)  # add a batch dimension
    batch = batch.to("cpu")

    with torch.no_grad():
        # NOTE(review): assumes the model returns a sequence of outputs whose
        # last entry is the final prediction map. Confirm against BiRefNet.
        final_logits = birefnet(batch)[-1]
        probabilities = final_logits.sigmoid().cpu()

    # Take the first (only) batch item and drop singleton dimensions
    # (presumably the 1-channel axis) before converting to a PIL mask.
    mask_map = probabilities[0].squeeze()
    alpha = transforms.ToPILImage()(mask_map)
    alpha = alpha.resize(width_height)
    image.putalpha(alpha)
    return image
def fn(image):
    """Remove the background from *image* for the Gradio interface.

    Args:
        image: Any input ``load_img`` accepts. Here that is the value
            produced by the ``gr.Image`` upload component or an example
            file path.

    Returns:
        A ``(slider_images, jpg_path)`` pair:

        - ``slider_images`` is the ``(background_removed, original)`` pair
          of PIL images compared by ``ImageSlider``.
        - ``jpg_path`` is the path of a JPEG of the background-removed image,
          flattened onto a white background.
    """
    import tempfile  # local import keeps this block self-contained

    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    origin = im.copy()  # process() adds an alpha channel to `im` in place
    processed_image = process(im)

    # JPEG has no alpha channel, so composite the cut-out onto white.
    # A plain convert("RGB") would discard the mask and just return the
    # original pixels.
    jpg_image = Image.new("RGB", processed_image.size, (255, 255, 255))
    jpg_image.paste(processed_image, (0, 0), processed_image.getchannel("A"))

    # Use a unique file per request so concurrent users never overwrite or
    # receive each other's output. Files are left on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        jpg_path = tmp.name
    jpg_image.save(jpg_path, format="JPEG")

    # ImageSlider compares a pair of images.
    return (processed_image, origin), jpg_path
def process_file(f):
    """Remove the background from the image file at path *f*.

    The result has a transparent background. It is saved next to the input
    with a ``.png`` extension, and that path is returned. If *f* is itself a
    ``.png`` file, it is overwritten.

    Args:
        f: Path (str or path-like) to an image file readable by ``load_img``.

    Returns:
        The path of the written PNG, as a string.
    """
    # splitext strips only the extension of the final path component:
    # "dir.v2/photo" -> "dir.v2/photo.png". Splitting on the last "." in the
    # whole path would give "dir.png". splitext also accepts path-like objects.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Gradio components used by the interface (created in this order on purpose:
# Gradio numbers components as they are instantiated).
slider1 = ImageSlider(type="pil", label="Processed Image")
image_upload = gr.Image(label="Upload an image")
output_download = gr.File(label="Download JPG File")

# Example images shown under the input. The paths are relative, so these
# files are expected alongside app.py (the directory the app is launched from).
sample_images = [
    "1.png",
    "2.jpg",
    "3.png",
]
# Gradio interface: upload one image and get back a before/after slider plus
# a JPG download. Title and description are set on this Interface directly.
# gr.Interface's required positional parameters are (fn, inputs, outputs),
# so an existing Interface cannot be re-wrapped in another gr.Interface.
tab1 = gr.Interface(
    fn=fn,
    inputs=image_upload,
    outputs=[slider1, output_download],
    examples=sample_images,
    api_name="image",
    title="Background Removal Tool",
    description="이미지를 업로드하면 배경이 제거된 이미지를 확인하고 JPG 파일로 다운로드할 수 있습니다.",
)
demo = tab1  # the app object launched below

if __name__ == "__main__":
    # show_error=True surfaces exceptions raised in `fn` in the browser UI.
    demo.launch(show_error=True)