Spaces:
Sleeping
Sleeping
File size: 7,165 Bytes
3d3a8e1 3fcc660 3d3a8e1 3fcc660 3d3a8e1 3fcc660 8239fe8 3fcc660 3d3a8e1 0623b90 3d3a8e1 3fcc660 1e22887 3fcc660 3d3a8e1 3fcc660 3d3a8e1 3fcc660 3d3a8e1 3fcc660 3d3a8e1 3fcc660 4775a16 3fcc660 3d3a8e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
# import gradio as gr
# import torch
# import uuid
# from PIL import Image
# from torchvision import transforms
# from transformers import AutoModelForImageSegmentation
# from typing import Union, List
# from loadimg import load_img # Your helper to load from URL or file
# torch.set_float32_matmul_precision("high")
# # Load BiRefNet model
# birefnet = AutoModelForImageSegmentation.from_pretrained(
# "ZhengPeng7/BiRefNet", trust_remote_code=True
# )
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# birefnet.to(device)
# # Image transformation
# transform_image = transforms.Compose([
# transforms.Resize((1024, 1024)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])
# def process(image: Image.Image) -> Image.Image:
# image_size = image.size
# input_tensor = transform_image(image).unsqueeze(0).to(device)
# with torch.no_grad():
# preds = birefnet(input_tensor)[-1].sigmoid().cpu()
# pred = preds[0].squeeze()
# mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
# binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
# white_bg = Image.new("RGB", image_size, (255, 255, 255))
# result = Image.composite(image, white_bg, binary_mask)
# return result
# def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
# results = []
# try:
# # Single image upload
# if image is not None:
# image = image.convert("RGB")
# processed = process(image)
# filename = f"output_{uuid.uuid4().hex[:8]}.png"
# processed.save(filename)
# return filename
# # Single image from URL
# if image_url:
# im = load_img(image_url, output_type="pil").convert("RGB")
# processed = process(im)
# filename = f"output_{uuid.uuid4().hex[:8]}.png"
# processed.save(filename)
# return filename
# # Batch of URLs
# if batch_urls:
# urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
# for url in urls:
# try:
# im = load_img(url, output_type="pil").convert("RGB")
# processed = process(im)
# filename = f"output_{uuid.uuid4().hex[:8]}.png"
# processed.save(filename)
# results.append(filename)
# except Exception as e:
# print(f"Error with {url}: {e}")
# return results if results else None
# except Exception as e:
# print("General error:", e)
# return None
# # Interface
# demo = gr.Interface(
# fn=handler,
# inputs=[
# gr.Image(label="Upload Image", type="pil"),
# gr.Textbox(label="Paste Image URL"),
# gr.Textbox(label="Comma-separated Image URLs (Batch)"),
# ],
# outputs=gr.File(label="Output File(s)", file_count="multiple"),
# title="Background Remover (White Fill)",
# description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
# )
# if __name__ == "__main__":
# demo.launch(show_error=True, mcp_server=True)
import base64
import uuid
from io import BytesIO
from typing import Union, List
from urllib.parse import unquote_to_bytes

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from loadimg import load_img # Your helper to load from URL or file
# Allow faster, reduced-internal-precision float32 matmuls where the backend supports it.
torch.set_float32_matmul_precision("high")

# Use the GPU when torch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BiRefNet segmentation model. trust_remote_code lets transformers
# execute the model code hosted in the Hub repository.
# NOTE(review): consider pinning `revision=` so that remote code cannot change underneath the app.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Model preprocessing: fixed 1024x1024 input, conversion to a tensor, then
# per-channel normalisation with the standard ImageNet mean / std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
def load_image_from_data_url(data_url: str) -> Image.Image:
"""Load image from base64 data URL"""
if data_url.startswith("data:image/"):
# Extract base64 data after the comma
if "," in data_url:
header, encoded = data_url.split(",", 1)
image_data = base64.b64decode(encoded)
return Image.open(BytesIO(image_data))
else:
raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
else:
# Regular URL, use existing load_img function
return load_img(data_url, output_type="pil")
def process(
    image: Image.Image,
    threshold: int = 127,
    background: tuple = (255, 255, 255),
) -> Image.Image:
    """Replace the background of *image* with a solid colour using BiRefNet.

    The model's prediction is resized back to the input size, thresholded into
    a hard 0/255 mask, and used to composite the original pixels over a flat
    background.

    Args:
        image: Input picture. If it is not already RGB it is converted first:
            the 3-channel normalisation transform and ``Image.composite``
            (which needs matching modes) both require an RGB image.
        threshold: Mask values (0-255) strictly greater than this are kept as
            foreground. The default of 127 matches the previous hard-coded value.
        background: RGB fill colour for removed pixels. The default is white.

    Returns:
        A new RGB image with the same size as *image*.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    size = image.size
    batch = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The last model output is used. The sigmoid presumably maps logits to [0, 1].
        probs = birefnet(batch)[-1].sigmoid().cpu()
    prob_map = probs[0].squeeze()
    mask = transforms.ToPILImage()(prob_map).resize(size).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > threshold else 0)
    fill = Image.new("RGB", size, background)
    return Image.composite(image, fill, binary_mask)
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
results = []
try:
# Single image upload
if image is not None:
image = image.convert("RGB")
processed = process(image)
filename = f"output_{uuid.uuid4().hex[:8]}.png"
processed.save(filename)
return filename
# Single image from URL (supports both regular URLs and base64 data URLs)
if image_url:
im = load_image_from_data_url(image_url).convert("RGB")
processed = process(im)
filename = f"output_{uuid.uuid4().hex[:8]}.png"
processed.save(filename)
return filename
# Batch of URLs (supports both regular URLs and base64 data URLs)
if batch_urls:
urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
for url in urls:
try:
im = load_image_from_data_url(url).convert("RGB")
processed = process(im)
filename = f"output_{uuid.uuid4().hex[:8]}.png"
processed.save(filename)
results.append(filename)
except Exception as e:
print(f"Error with {url}: {e}")
return results if results else None
except Exception as e:
print("General error:", e)
return None
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
# Positional order matches handler(image, image_url, batch_urls).
_inputs = [
    gr.Image(label="Upload Image", type="pil"),
    gr.Textbox(label="Paste Image URL"),
    gr.Textbox(label="Comma-separated Image URLs (Batch)"),
]
_outputs = gr.File(label="Output File(s)", file_count="multiple")

demo = gr.Interface(
    fn=handler,
    inputs=_inputs,
    outputs=_outputs,
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
)

if __name__ == "__main__":
    # mcp_server=True asks Gradio to also serve this app over MCP.
    demo.launch(mcp_server=True, show_error=True)
|