Spaces:
Runtime error
Runtime error
File size: 3,965 Bytes
e617194 11e2057 1ea5bb8 b0637c4 e617194 b0637c4 e617194 b0637c4 e617194 1ea5bb8 e617194 b0637c4 e617194 b0637c4 e617194 1ea5bb8 e617194 b0637c4 e617194 b0637c4 e617194 1ea5bb8 b0637c4 e617194 1ea5bb8 e617194 b0637c4 e617194 b0637c4 e617194 1ea5bb8 e617194 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import os
import zipfile
import numpy as np
from PIL import Image
# Allow reduced-precision (e.g. TF32) float32 matmuls for speed; "highest"
# would force full float32 precision.
torch.set_float32_matmul_precision("high")

# trust_remote_code=True lets transformers execute the model code hosted in
# the "ZhengPeng7/BiRefNet" Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
# Fall back to CPU when no GPU is available instead of crashing at import.
birefnet.to("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every input: resize to 1024x1024, convert to a
# float tensor, and normalise with the ImageNet channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
_OUTPUT_DIR = "output_images"


def _output_path(filename):
    """Return ``filename`` joined onto the output directory, creating it if needed."""
    # Nothing else in this script creates the directory; without this,
    # PIL's save() raises FileNotFoundError on a fresh checkout.
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    return os.path.join(_OUTPUT_DIR, filename)


def _remove_background(image):
    """Run BiRefNet on ``image`` and return ``(cutout, mask)``.

    Args:
        image: whatever ``load_img`` accepts; the callers in this file pass a
            Gradio image value, a URL string, or a file path.

    Returns:
        ``mask`` is the predicted foreground mask as a PIL image resized back
        to the input size; ``cutout`` is the input converted to RGB with
        ``mask`` attached as its alpha channel (so it becomes RGBA).
    """
    im = load_img(image, output_type="pil").convert("RGB")
    # Use whatever device the weights were placed on rather than assuming CUDA.
    device = next(birefnet.parameters()).device
    input_images = transform_image(im).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model output is indexed with [-1], so only its last prediction
        # is used; sigmoid maps it to [0, 1] (presumably it is a logit map).
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(im.size)
    im.putalpha(mask)
    return im, mask


def fn(image):
    """Remove the background from an uploaded image.

    Returns:
        ``([cutout, mask], png_path)``: the pair for the image slider and the
        path of the saved RGBA PNG offered for download.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    im, mask = _remove_background(image)
    output_path = _output_path("output_image_processed.png")
    im.save(output_path, "PNG")
    return [im, mask], output_path


def fn_url(url):
    """Remove the background from the image at ``url``.

    Returns:
        ``([cutout, mask], png_path)``, same shape as ``fn``.

    Raises:
        gr.Error: if the URL is empty or whitespace only.
    """
    if not url or not url.strip():
        raise gr.Error("Please paste an image URL.")
    im, mask = _remove_background(url.strip())
    output_path = _output_path("output_image_url_processed.png")
    im.save(output_path, "PNG")
    return [im, mask], output_path


def batch_fn(images):
    """Remove the background from several images and bundle them in a zip.

    Args:
        images: list of uploaded file paths (``None`` when nothing was uploaded).

    Returns:
        Path of a zip archive containing one RGBA PNG per input image, named
        ``output_image_batch_<n>.png`` with ``n`` starting at 1.

    Raises:
        gr.Error: if no images were uploaded.
    """
    if not images:
        raise gr.Error("Please upload at least one image.")
    output_paths = []
    for idx, image_path in enumerate(images, start=1):
        im, _ = _remove_background(image_path)
        output_file_path = _output_path(f"output_image_batch_{idx}.png")
        im.save(output_file_path)
        output_paths.append(output_file_path)
    zip_file_path = _output_path("processed_images.zip")
    with zipfile.ZipFile(zip_file_path, "w") as zipf:
        for path in output_paths:
            zipf.write(path, os.path.basename(path))
    return zip_file_path
# ---------------------------------------------------------------------------
# Gradio UI: one tab per entry point (single upload, URL, batch upload).
# ---------------------------------------------------------------------------

# Input/output components used by the tabs below.
batch_image = gr.File(
    label="Upload multiple images",
    type="filepath",
    file_count="multiple",
)
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")

# Example inputs: a local image file and a remote image URL.
chameleon = load_img("chameleon.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[slider1, gr.File(label="PNG Output")],
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn=fn_url,
    inputs=text,
    outputs=[slider2, gr.File(label="PNG Output")],
    examples=[url],
    api_name="text",
)
tab3 = gr.Interface(
    fn=batch_fn,
    inputs=batch_image,
    outputs=gr.File(label="Download Processed Files"),
    api_name="batch",
)

demo = gr.TabbedInterface(
    interface_list=[tab1, tab2, tab3],
    tab_names=["image", "text", "batch"],
    title="Multi Birefnet for Background Removal",
)

if __name__ == "__main__":
    # Start the Gradio server when run as a script.
    demo.launch()
|