user-agent's picture
Update app.py
1a1c89b verified
raw
history blame
3.63 kB
import gradio as gr
import spaces
import torch
import uuid
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Union, List
from loadimg import load_img
# Trade a little float32 matmul precision for speed on hardware that has
# faster reduced-precision kernels.
torch.set_float32_matmul_precision("high")

# Load RMBG v1.4 model.
# trust_remote_code=True lets transformers execute the model code shipped in
# the Hub repository, so only use it with repositories you trust.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-1.4",
    trust_remote_code=True,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: make sure dropout / batch-norm layers run in eval mode.
model.eval()

# Transform for RMBG v1.4.
# The RMBG-1.4 model card preprocesses with mean 0.5 and std 1.0 per channel
# (after scaling pixels to [0, 1]); ImageNet statistics belong to other
# checkpoints and shift the input distribution this model expects.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
])
def _select_mask_tensor(preds):
    """Pick the mask tensor out of a raw segmentation-model output.

    Flat outputs keep the historical rules: a bare tensor is used as-is, a
    list yields its last element and a tuple its first. When the element
    chosen that way is itself a list/tuple (nested output), the first leaf is
    taken instead, i.e. ``preds[0][0]...``.

    Raises:
        ValueError: if an empty list/tuple is encountered at any level.
    """
    if not isinstance(preds, (list, tuple)):
        return preds
    if not preds:
        raise ValueError("Model returned an empty prediction container")

    chosen = preds[-1] if isinstance(preds, list) else preds[0]
    if not isinstance(chosen, (list, tuple)):
        return chosen

    # Nested output: walk down the first element of every level.
    # NOTE(review): this matches the `result[0][0]` indexing shown on the
    # RMBG-1.4 model card (first, finest side output) -- confirm if the
    # checkpoint is ever swapped.
    node = preds
    while isinstance(node, (list, tuple)):
        if not node:
            raise ValueError("Model returned an empty prediction container")
        node = node[0]
    return node


@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """Remove the background of *image* and replace it with white.

    Args:
        image: Input picture. Any PIL mode is accepted; it is converted to
            RGB before being fed to the model.

    Returns:
        An RGB image of the same size as the input in which every pixel the
        model classifies as background is pure white.

    Raises:
        ValueError: if the model returns an empty list/tuple of predictions.
    """
    # The normalisation step expects exactly three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_size = image.size  # PIL order: (width, height)

    input_tensor = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = _select_mask_tensor(model(input_tensor))
        # Only squash raw logits. If the model already returned values in
        # [0, 1], a second sigmoid would compress them into [0.5, 0.73] and
        # push almost every pixel over the threshold below.
        # NOTE(review): RMBG-1.4's remote code appears to apply sigmoid itself
        # -- confirm against the model repository.
        if pred.min().item() < 0 or pred.max().item() > 1:
            pred = pred.sigmoid()
    mask = pred.cpu()

    # First item of the batch, with singleton dimensions dropped.
    mask_tensor = mask[0].squeeze()
    mask_pil = transforms.ToPILImage()(mask_tensor).resize(image_size).convert("L")

    # Hard foreground/background decision at the mid-point of the 8-bit range.
    binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)

    # Keep the original pixels where the mask is set, white everywhere else.
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _remove_background_to_file(image: Image.Image) -> str:
    """Run ``process`` on *image* and save the result as a uniquely named PNG.

    The file is written to the current working directory as
    ``output_<8 hex chars>.png``; the file name is returned.
    """
    processed = process(image.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


@spaces.GPU
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from an uploaded image, a URL, or a batch of URLs.

    Exactly one input is used, chosen in this order of precedence:
    ``image``, then ``image_url``, then ``batch_urls``.

    Args:
        image: PIL image uploaded by the user, or None.
        image_url: URL of a single image. Surrounding whitespace is ignored
            and a blank value is treated as "not provided".
        batch_urls: Comma-separated image URLs. Each URL is processed
            independently; URLs that fail are logged and skipped.

    Returns:
        The output file name for a single input, a list of output file names
        for a batch, or None when no input was given, every batch item
        failed, or an unexpected error occurred (the error is printed).
    """
    try:
        if image is not None:
            return _remove_background_to_file(image)

        # A whitespace-only textbox must not shadow the batch field.
        image_url = (image_url or "").strip()
        if image_url:
            return _remove_background_to_file(load_img(image_url, output_type="pil"))

        if batch_urls:
            results = []
            for url in (part.strip() for part in batch_urls.split(",")):
                if not url:
                    continue
                # Best effort: one bad URL must not abort the whole batch.
                try:
                    results.append(
                        _remove_background_to_file(load_img(url, output_type="pil"))
                    )
                except Exception as e:
                    print(f"Error with {url}: {e}")
            return results if results else None
    except Exception as e:
        # Top-level boundary for the UI callback: log and fall through to None.
        print("General error:", e)
        import traceback
        traceback.print_exc()
    return None
# Input widgets, listed in the positional order `handler` receives them:
# (image, image_url, batch_urls).
_upload_input = gr.Image(type="pil", label="Upload Image")
_url_input = gr.Textbox(label="Paste Image URL")
_batch_input = gr.Textbox(label="Comma-separated Image URLs (Batch)")

# `handler` may return one path or a list of paths, hence file_count="multiple".
_file_output = gr.File(file_count="multiple", label="Output File(s)")

demo = gr.Interface(
    title="Background Remover (RMBG v1.4)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
    fn=handler,
    inputs=[_upload_input, _url_input, _batch_input],
    outputs=_file_output,
)

if __name__ == "__main__":
    # Launch with error reporting in the UI and the MCP server enabled.
    demo.launch(mcp_server=True, show_error=True)