Spaces:
Sleeping
Sleeping
# import gradio as gr | |
# import torch | |
# import uuid | |
# from PIL import Image | |
# from torchvision import transforms | |
# from transformers import AutoModelForImageSegmentation | |
# from typing import Union, List | |
# from loadimg import load_img # Your helper to load from URL or file | |
# torch.set_float32_matmul_precision("high") | |
# # Load BiRefNet model | |
# birefnet = AutoModelForImageSegmentation.from_pretrained( | |
# "ZhengPeng7/BiRefNet", trust_remote_code=True | |
# ) | |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# birefnet.to(device) | |
# # Image transformation | |
# transform_image = transforms.Compose([ | |
# transforms.Resize((1024, 1024)), | |
# transforms.ToTensor(), | |
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | |
# ]) | |
# def process(image: Image.Image) -> Image.Image: | |
# image_size = image.size | |
# input_tensor = transform_image(image).unsqueeze(0).to(device) | |
# with torch.no_grad(): | |
# preds = birefnet(input_tensor)[-1].sigmoid().cpu() | |
# pred = preds[0].squeeze() | |
# mask = transforms.ToPILImage()(pred).resize(image_size).convert("L") | |
# binary_mask = mask.point(lambda p: 255 if p > 127 else 0) | |
# white_bg = Image.new("RGB", image_size, (255, 255, 255)) | |
# result = Image.composite(image, white_bg, binary_mask) | |
# return result | |
# def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]: | |
# results = [] | |
# try: | |
# # Single image upload | |
# if image is not None: | |
# image = image.convert("RGB") | |
# processed = process(image) | |
# filename = f"output_{uuid.uuid4().hex[:8]}.png" | |
# processed.save(filename) | |
# return filename | |
# # Single image from URL | |
# if image_url: | |
# im = load_img(image_url, output_type="pil").convert("RGB") | |
# processed = process(im) | |
# filename = f"output_{uuid.uuid4().hex[:8]}.png" | |
# processed.save(filename) | |
# return filename | |
# # Batch of URLs | |
# if batch_urls: | |
# urls = [u.strip() for u in batch_urls.split(",") if u.strip()] | |
# for url in urls: | |
# try: | |
# im = load_img(url, output_type="pil").convert("RGB") | |
# processed = process(im) | |
# filename = f"output_{uuid.uuid4().hex[:8]}.png" | |
# processed.save(filename) | |
# results.append(filename) | |
# except Exception as e: | |
# print(f"Error with {url}: {e}") | |
# return results if results else None | |
# except Exception as e: | |
# print("General error:", e) | |
# return None | |
# # Interface | |
# demo = gr.Interface( | |
# fn=handler, | |
# inputs=[ | |
# gr.Image(label="Upload Image", type="pil"), | |
# gr.Textbox(label="Paste Image URL"), | |
# gr.Textbox(label="Comma-separated Image URLs (Batch)"), | |
# ], | |
# outputs=gr.File(label="Output File(s)", file_count="multiple"), | |
# title="Background Remover (White Fill)", | |
# description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.", | |
# ) | |
# if __name__ == "__main__": | |
# demo.launch(show_error=True, mcp_server=True) | |
import gradio as gr | |
import torch | |
import uuid | |
import base64 | |
from PIL import Image | |
from torchvision import transforms | |
from transformers import AutoModelForImageSegmentation | |
from typing import Union, List | |
from loadimg import load_img # Your helper to load from URL or file | |
from io import BytesIO | |
# Let float32 matmuls use faster reduced-internal-precision kernels where the
# hardware supports them.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model. trust_remote_code=True is required because the
# model's implementation is loaded from the Hub repository itself.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
birefnet.to(device)

# Preprocessing applied to every input image: fixed 1024x1024 resize, conversion
# to a float tensor, then per-channel normalization with the ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
def load_image_from_data_url(data_url: str) -> Image.Image:
    """Load a PIL image from a ``data:image/...`` URL or a regular URL.

    Data URLs are decoded locally:

    * ``data:image/<type>;base64,<payload>`` -- the payload is base64-decoded.
    * ``data:image/<type>,<payload>`` (no ``;base64``) -- the payload is
      percent-decoded, as specified for the data URL scheme (RFC 2397).

    Anything else is handed to ``load_img``.

    Args:
        data_url: A data URL, or any URL/path accepted by ``load_img``.
            Leading/trailing whitespace (common in pasted input) is ignored.

    Returns:
        The decoded image, in whatever mode the encoded file uses.

    Raises:
        ValueError: If a data URL has no comma separating header from payload,
            or its base64 payload is malformed (``binascii.Error`` is a
            ``ValueError`` subclass).
    """
    from urllib.parse import unquote_to_bytes

    data_url = data_url.strip()

    # The "data:" scheme and the media type are case-insensitive. Only the
    # prefix is lowercased so multi-megabyte payloads are not copied.
    if data_url[:11].lower() != "data:image/":
        # Regular URL: use the shared loader.
        return load_img(data_url, output_type="pil")

    header, sep, payload = data_url.partition(",")
    if not sep:
        raise ValueError(f"Invalid data URL format: {data_url[:50]}...")

    if header.lower().endswith(";base64"):
        image_data = base64.b64decode(payload)
    else:
        # Non-base64 data URLs carry percent-encoded bytes.
        image_data = unquote_to_bytes(payload)
    return Image.open(BytesIO(image_data))
def process(image: Image.Image) -> Image.Image:
    """Replace the background of *image* with solid white.

    The image goes through ``transform_image`` and ``birefnet``. The sigmoid of
    the model's last output is resized back to the original size and
    thresholded into a binary mask. The original image is then composited over
    a white canvas using that mask.

    Args:
        image: Input PIL image. Images that are not already RGB (e.g. RGBA, L,
            P) are converted first. Both the 3-channel normalization and
            ``Image.composite`` against the RGB white canvas require RGB, so
            such inputs would otherwise raise.

    Returns:
        A new RGB image of the same size. Pixels classified as background are
        white.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Only the last element of the model output is used.
        # NOTE(review): presumably the final-stage prediction -- confirm against BiRefNet.
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # ToPILImage scales the [0, 1] probabilities to 0..255.
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
    # Hard threshold at roughly 50% foreground probability.
    binary_mask = mask.point(lambda p: 255 if p > 127 else 0)

    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _save_processed(image) -> str:
    """Remove the background from a PIL image and save it as a uniquely named PNG in the CWD."""
    processed = process(image.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


def _split_batch_urls(batch_urls: str) -> List[str]:
    """Split a comma-separated URL list without cutting data URLs in half.

    A data URL such as ``data:image/png;base64,<payload>`` contains a comma
    between its header and payload, so a plain ``split(",")`` would separate
    the two. When a segment starts with ``data:``, the next segment is
    re-attached to it as its payload; base64 payloads never contain commas.
    Regular URLs that themselves contain commas are still split, as before.
    Empty segments are dropped.
    """
    parts = iter(p.strip() for p in batch_urls.split(","))
    urls: List[str] = []
    for part in parts:
        if part[:5].lower() == "data:":
            payload = next(parts, None)
            if payload is not None:
                part = f"{part},{payload}"
        if part:
            urls.append(part)
    return urls


def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Gradio entry point: remove the background from one or more images.

    Only the first provided input is used, checked in this order:

    1. ``image`` -- an uploaded PIL image.
    2. ``image_url`` -- a single regular URL or ``data:image/...`` URL.
    3. ``batch_urls`` -- comma-separated URLs; data URLs are allowed.

    Each result is saved as ``output_<8 hex chars>.png`` in the current
    working directory.

    Returns:
        The saved filename for a single image, or a list of filenames for a
        batch. Failed batch items are skipped. Returns ``None`` when no input
        was given or nothing succeeded. Errors are printed rather than raised,
        so the caller gets an empty result instead of an exception.
    """
    try:
        if image is not None:
            return _save_processed(image)

        # Pasted URLs often carry stray spaces/newlines.
        single_url = image_url.strip() if image_url else ""
        if single_url:
            return _save_processed(load_image_from_data_url(single_url))

        if batch_urls:
            results = []
            for url in _split_batch_urls(batch_urls):
                try:
                    results.append(_save_processed(load_image_from_data_url(url)))
                except Exception as e:
                    # Best effort: report this URL and continue with the rest.
                    print(f"Error with {url}: {e}")
            return results if results else None

        return None
    except Exception as e:
        print("General error:", e)
        return None
# Gradio UI. The three optional inputs map positionally onto handler's
# (image, image_url, batch_urls) parameters. The upload arrives as a PIL image
# because of type="pil".
_inputs = [
    gr.Image(label="Upload Image", type="pil"),
    gr.Textbox(label="Paste Image URL"),
    gr.Textbox(label="Comma-separated Image URLs (Batch)"),
]
# handler returns either one filename or a list of filenames; a multi-file
# output component displays both.
_outputs = gr.File(label="Output File(s)", file_count="multiple")

demo = gr.Interface(
    fn=handler,
    inputs=_inputs,
    outputs=_outputs,
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
)

if __name__ == "__main__":
    # show_error surfaces raised exceptions in the UI.
    # mcp_server=True asks Gradio to also expose the app over MCP.
    demo.launch(show_error=True, mcp_server=True)