# Captured from a Hugging Face Spaces file viewer (Space status: Running).
# File size: 2,648 bytes.
# Commits shown in the viewer's blame gutter: 8ce07a1, b09a145, a505bd1,
# 797deb4, 2b13f4e, 3145448.
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Model loading was moved inside the handler functions so the model is loaded when a GPU is allocated.
# Preprocessing applied to every input before it reaches the segmentation
# model: resize to a fixed 1024x1024 input, convert the PIL image to a tensor,
# then normalize each channel with the standard ImageNet mean/std values.
_MODEL_INPUT_SIZE = (1024, 1024)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform_image = transforms.Compose(_preprocess_steps)
@spaces.GPU
def fn(image):
    """Remove the background from one image and return a before/after pair.

    Args:
        image: The value delivered by the Gradio input component -- an image
            from ``gr.Image`` or a URL string from ``gr.Textbox``; anything
            ``load_img`` accepts.

    Returns:
        A tuple ``(processed_image, origin)``: the RGB image with the
        predicted mask applied as its alpha channel, and an untouched RGB
        copy of the input, for the ``ImageSlider`` output.

    Raises:
        gr.Error: If no image was given or the URL is empty.
    """
    # Validate and decode the input before loading the model, so a missing or
    # unreadable image fails fast instead of after a full model load.
    if isinstance(image, str):
        image = image.strip()
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Please provide an image or an image URL.")
    im = load_img(image, output_type="pil").convert("RGB")
    # Keep a pristine copy for the "before" side of the slider.
    origin = im.copy()

    # The model is loaded inside this GPU-decorated function so that loading
    # happens while a GPU is allocated.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    birefnet.to("cuda")
    processed_image = process(im, birefnet)
    return (processed_image, origin)
def process(image, model, device="cuda"):
    """Predict a foreground mask for ``image`` and return it as the alpha channel.

    Args:
        image: A PIL image. Callers in this file convert it to RGB first;
            ``transform_image`` normalizes exactly three channels.
        model: The segmentation model. Its output is indexed with ``[-1]``,
            and that last element is treated as mask logits (passed through
            ``sigmoid``).
        device: Device the preprocessed input tensor is moved to. It must
            match the device the model lives on. Defaults to ``"cuda"``.

    Returns:
        A new PIL image: a copy of ``image`` whose alpha channel is the
        predicted mask, resized back to the original image size. ``image``
        itself is left unmodified.
    """
    # PIL's size is (width, height), which is the order resize() expects.
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Prediction
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Attach the mask to a copy so the caller's image is not mutated in place.
    result = image.copy()
    result.putalpha(mask)
    return result
@spaces.GPU
def process_file(f):
    """Remove the background from an image file and save the result as a PNG.

    Args:
        f: Path to the input image file (as produced by
            ``gr.Image(type="filepath")``).

    Returns:
        The path, as a string, of the written PNG. This is the input path with
        its extension replaced by ``.png``. When the input already has a
        ``.png`` extension, ``_nobg`` is appended to the stem so the input
        file is not overwritten.

    Raises:
        gr.Error: If no file path was provided.
    """
    from pathlib import Path

    if not f:
        raise gr.Error("Please upload an image.")

    src = Path(f)
    # with_suffix/with_name only touch the final path component, so a dot in
    # a directory name is never mistaken for the file extension.
    if src.suffix.lower() == ".png":
        out_path = src.with_name(f"{src.stem}_nobg.png")
    else:
        out_path = src.with_suffix(".png")

    # Decode the input before loading the model so bad input fails fast.
    im = load_img(f, output_type="pil").convert("RGB")

    # Loaded inside the GPU-decorated function so loading happens while a GPU
    # is allocated.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    birefnet.to("cuda")
    transparent = process(im, birefnet)
    transparent.save(out_path)
    return str(out_path)
# ---- Gradio components (created in the same order as before) ----
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# ---- Example inputs ----
# NOTE: despite its name, `chameleon` holds butterfly.jpg; the module-level
# name is kept unchanged.
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# ---- One Interface per tab; api_name sets each tab's API endpoint name ----
tab1 = gr.Interface(
    fn=fn,
    inputs=image_upload,
    outputs=slider1,
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn=fn,
    inputs=url_input,
    outputs=slider2,
    examples=[url_example],
    api_name="text",
)
tab3 = gr.Interface(
    fn=process_file,
    inputs=image_file_upload,
    outputs=output_file,
    examples=["butterfly.jpg"],
    api_name="png",
)

demo = gr.TabbedInterface(
    [tab1, tab2, tab3],
    tab_names=["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool",
)
if __name__ == "__main__":
    # show_error=True surfaces exceptions raised by the handlers in the browser.
    demo.launch(show_error=True)