# user-agent's picture
# Update app.py
# 0623b90 verified
# import gradio as gr
# import torch
# import uuid
# from PIL import Image
# from torchvision import transforms
# from transformers import AutoModelForImageSegmentation
# from typing import Union, List
# from loadimg import load_img # Your helper to load from URL or file
# torch.set_float32_matmul_precision("high")
# # Load BiRefNet model
# birefnet = AutoModelForImageSegmentation.from_pretrained(
# "ZhengPeng7/BiRefNet", trust_remote_code=True
# )
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# birefnet.to(device)
# # Image transformation
# transform_image = transforms.Compose([
# transforms.Resize((1024, 1024)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])
# def process(image: Image.Image) -> Image.Image:
# image_size = image.size
# input_tensor = transform_image(image).unsqueeze(0).to(device)
# with torch.no_grad():
# preds = birefnet(input_tensor)[-1].sigmoid().cpu()
# pred = preds[0].squeeze()
# mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
# binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
# white_bg = Image.new("RGB", image_size, (255, 255, 255))
# result = Image.composite(image, white_bg, binary_mask)
# return result
# def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
# results = []
# try:
# # Single image upload
# if image is not None:
# image = image.convert("RGB")
# processed = process(image)
# filename = f"output_{uuid.uuid4().hex[:8]}.png"
# processed.save(filename)
# return filename
# # Single image from URL
# if image_url:
# im = load_img(image_url, output_type="pil").convert("RGB")
# processed = process(im)
# filename = f"output_{uuid.uuid4().hex[:8]}.png"
# processed.save(filename)
# return filename
# # Batch of URLs
# if batch_urls:
# urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
# for url in urls:
# try:
# im = load_img(url, output_type="pil").convert("RGB")
# processed = process(im)
# filename = f"output_{uuid.uuid4().hex[:8]}.png"
# processed.save(filename)
# results.append(filename)
# except Exception as e:
# print(f"Error with {url}: {e}")
# return results if results else None
# except Exception as e:
# print("General error:", e)
# return None
# # Interface
# demo = gr.Interface(
# fn=handler,
# inputs=[
# gr.Image(label="Upload Image", type="pil"),
# gr.Textbox(label="Paste Image URL"),
# gr.Textbox(label="Comma-separated Image URLs (Batch)"),
# ],
# outputs=gr.File(label="Output File(s)", file_count="multiple"),
# title="Background Remover (White Fill)",
# description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
# )
# if __name__ == "__main__":
# demo.launch(show_error=True, mcp_server=True)
import gradio as gr
import torch
import uuid
import base64
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Union, List
from loadimg import load_img # Your helper to load from URL or file
from io import BytesIO
torch.set_float32_matmul_precision("high")

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load BiRefNet model (its modelling code is fetched from the Hub, hence
# trust_remote_code) and move it onto the selected device.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Image transformation: fixed 1024x1024 resize, conversion to a tensor, then
# per-channel normalisation with the ImageNet mean/std (three channels, so
# inputs are expected to be RGB).
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_image_from_data_url(data_url: str) -> "Image.Image":
    """Load an image from a base64 ``data:`` URL, or from a regular URL via ``load_img``.

    ``data:`` URLs are decoded locally; any other string is handed to
    ``load_img`` unchanged (apart from trimming surrounding whitespace).

    Args:
        data_url: A ``data:<mediatype>;base64,<payload>`` URL, or anything
            ``load_img`` accepts. Surrounding whitespace is ignored, and the
            ``data:`` scheme is matched case-insensitively.

    Returns:
        The opened PIL image (not converted; callers call ``.convert("RGB")``).

    Raises:
        ValueError: If a ``data:`` URL has no comma separating header and
            payload, or its payload is not decodable base64.
    """
    # Textbox input often carries stray spaces/newlines around the URL.
    data_url = data_url.strip()
    # Compare only the first 5 chars: avoids lower-casing a multi-MB payload.
    if data_url[:5].lower() != "data:":
        # Regular URL, use existing load_img function
        return load_img(data_url, output_type="pil")

    # Everything after the first comma is the payload.
    _, sep, payload = data_url.partition(",")
    if not sep:
        raise ValueError(f"Invalid data URL format: {data_url[:50]}...")

    # Drop embedded whitespace (line-wrapped base64) and restore any "="
    # padding that clients sometimes strip; a well-formed payload is unchanged.
    encoded = "".join(payload.split())
    encoded += "=" * (-len(encoded) % 4)
    try:
        image_data = base64.b64decode(encoded)
    except ValueError as err:  # binascii.Error subclasses ValueError
        raise ValueError(
            f"Invalid base64 payload in data URL: {data_url[:50]}..."
        ) from err
    return Image.open(BytesIO(image_data))
def process(image: Image.Image, threshold: int = 127) -> Image.Image:
    """Replace the background of *image* with solid white using BiRefNet.

    The image is resized/normalised by ``transform_image``, the model's
    prediction is turned into a mask at the original size, the mask is
    binarised, and the original pixels are composited over a white canvas.

    Args:
        image: Input image of any PIL mode. It is converted to RGB first,
            because ``transform_image`` normalises exactly three channels and
            ``Image.composite`` needs both images in the same mode.
        threshold: Mask values (0-255) strictly greater than this are kept as
            foreground; everything else becomes white. The default 127
            corresponds to a probability of roughly 0.5.

    Returns:
        A new RGB image with the same size as *image*.
    """
    image = image.convert("RGB")
    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last element of the model output (presumably the final
        # stage prediction) and squash the logits to [0, 1].
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
    # Hard threshold: the output has no soft/feathered edges.
    binary_mask = mask.point(lambda p: 255 if p > threshold else 0)
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _save_png(image: "Image.Image") -> str:
    """Save *image* as a uniquely named PNG in the working directory; return the filename."""
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    image.save(filename)
    return filename


def _split_batch_urls(batch_urls: str) -> List[str]:
    """Split a comma-separated URL list without breaking ``data:`` URLs.

    A base64 data URL contains exactly one comma (between its header and its
    payload, which itself never contains commas), so a fragment starting with
    ``data:`` is re-joined with the fragment that follows it. Empty fragments
    are dropped.
    """
    urls: List[str] = []
    pending: Union[str, None] = None  # a data-URL header awaiting its payload
    for part in batch_urls.split(","):
        part = part.strip()
        if pending is not None:
            urls.append(f"{pending},{part}")
            pending = None
        elif part[:5].lower() == "data:":
            pending = part
        elif part:
            urls.append(part)
    if pending is not None:
        # Header with no payload: pass it on so the loader reports it as malformed.
        urls.append(pending)
    return urls


def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from one or more images and fill it with white.

    Only the first provided input is used, checked in this order: ``image``,
    ``image_url``, ``batch_urls``.

    Args:
        image: An uploaded PIL image.
        image_url: A regular image URL or a base64 ``data:`` URL.
        batch_urls: Comma-separated image URLs; base64 ``data:`` URLs may be
            mixed in and are kept intact despite the comma inside them.

    Returns:
        The saved PNG filename for ``image``/``image_url``; a list of
        filenames for ``batch_urls`` (URLs that fail are logged and skipped);
        or None when nothing was produced or an error occurred.
    """
    try:
        # Single image upload
        if image is not None:
            return _save_png(process(image.convert("RGB")))
        # Single image from URL (supports both regular URLs and base64 data URLs)
        if image_url:
            im = load_image_from_data_url(image_url).convert("RGB")
            return _save_png(process(im))
        # Batch of URLs (supports both regular URLs and base64 data URLs)
        results = []
        if batch_urls:
            for url in _split_batch_urls(batch_urls):
                try:
                    im = load_image_from_data_url(url).convert("RGB")
                    results.append(_save_png(process(im)))
                except Exception as e:
                    # Best effort per URL; truncate so base64 blobs don't flood the log.
                    print(f"Error with {url[:100]}: {e}")
        return results if results else None
    except Exception as e:
        print("General error:", e)
        return None
# Interface
# Three optional inputs, passed in order to handler(image, image_url, batch_urls).
demo = gr.Interface(
    fn=handler,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),  # delivered to handler as a PIL image
        gr.Textbox(label="Paste Image URL"),
        gr.Textbox(label="Comma-separated Image URLs (Batch)"),
    ],
    # handler returns a single filename, a list of filenames, or None.
    outputs=gr.File(label="Output File(s)", file_count="multiple"),
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
)
if __name__ == "__main__":
    # handler catches its own exceptions and returns None, so show_error will
    # not surface those failures in the UI.
    # NOTE(review): mcp_server=True presumably requires Gradio's MCP extra to be installed — confirm.
    demo.launch(show_error=True, mcp_server=True)