Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,629 Bytes
3fcc660 1e65fde 3fcc660 d2a1709 3fcc660 d2a1709 3fcc660 8239fe8 d2a1709 3fcc660 461565d 3fcc660 d2a1709 3fcc660 d8a5e1e 3fcc660 1e22887 3fcc660 d2a1709 461565d 1a1c89b d2a1709 461565d 1a1c89b 461565d 3fcc660 461565d d2a1709 461565d d2a1709 3fcc660 d2a1709 3fcc660 d8a5e1e 3fcc660 461565d 3fcc660 4775a16 3fcc660 d2a1709 3fcc660 d2a1709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import spaces
import torch
import uuid
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Union, List
from loadimg import load_img
# Allow faster, slightly lower-precision float32 matmuls (e.g. TF32) where the
# hardware supports them.
torch.set_float32_matmul_precision("high")

# Load the RMBG v1.4 segmentation model. trust_remote_code=True is needed
# because the model architecture is shipped as custom code in the model repo.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-1.4",
    trust_remote_code=True,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: explicitly switch to evaluation mode (a no-op if
# from_pretrained already did so).
model.eval()

# Preprocessing for RMBG v1.4: resize to the 1024x1024 model input, scale to
# [0, 1] with ToTensor, then normalize with mean 0.5 / std 1.0 as in the model
# card's preprocess_image. (The ImageNet mean/std previously used here match
# the RMBG-2.0 recipe, not v1.4's, and feed the model out-of-distribution input.)
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
])
def _select_mask(preds):
    """Return the primary mask tensor from the model's (possibly nested) output.

    RMBG v1.4's remote code returns a tuple ``(side_outputs, features)`` whose
    first element is a list of side-output maps; the model card uses
    ``result[0][0]`` as the prediction. This descends through lists/tuples,
    always taking the first element, until a non-container is reached.

    Raises:
        ValueError: if an empty list/tuple is encountered on the way down.
    """
    while isinstance(preds, (list, tuple)):
        if not preds:
            raise ValueError("model returned an empty prediction container")
        preds = preds[0]
    return preds


@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """Remove the background of *image*, replacing it with white.

    Args:
        image: input PIL image; it is converted to RGB before inference so the
            3-channel normalization and the RGB composite below always apply.

    Returns:
        An RGB image the same size as *image* in which every pixel whose mask
        value is <= 127 (on a 0-255 scale) is replaced by pure white.

    Raises:
        ValueError: if the model output contains an empty list/tuple.
    """
    image = image.convert("RGB")
    image_size = image.size  # PIL (width, height), reused for the resize below
    input_tensor = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(input_tensor)
    pred = _select_mask(preds)

    # RMBG v1.4 already applies a sigmoid inside forward(); applying it again
    # squashes every value into [0.5, 0.73], so nearly every pixel passes the
    # >127 threshold. Only apply sigmoid when the output looks like raw logits.
    if pred.min().item() < 0 or pred.max().item() > 1:
        pred = pred.sigmoid()

    # Drop the batch dimension and any singleton channel dimension.
    mask_tensor = pred.cpu()[0].squeeze()
    # ToPILImage scales float [0, 1] to 8-bit; resize back to the input size.
    mask_pil = transforms.ToPILImage()(mask_tensor).resize(image_size).convert("L")
    # Hard threshold: keep pixels the model is confident belong to the foreground.
    binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)

    # Take the original pixel where the mask is 255, white where it is 0.
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _save_result(im: Image.Image) -> str:
    """Remove the background of *im* and save it as a uniquely named PNG.

    Returns:
        The output filename, relative to the current working directory.
    """
    processed = process(im)
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


@spaces.GPU
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Gradio entry point: remove the background from one or many images.

    Inputs are checked in priority order; only the first one provided is used:

    1. ``image`` - an uploaded PIL image.
    2. ``image_url`` - a single image URL (blank or whitespace-only is ignored;
       surrounding whitespace is stripped).
    3. ``batch_urls`` - a comma-separated list of image URLs.

    Returns:
        The output PNG filename for a single image, a list of filenames for a
        batch, or ``None`` when no input was given, every batch URL failed, or
        an unexpected error occurred. Errors are printed, never raised.
    """
    try:
        if image is not None:
            return _save_result(image.convert("RGB"))

        if image_url and image_url.strip():
            im = load_img(image_url.strip(), output_type="pil").convert("RGB")
            return _save_result(im)

        if batch_urls:
            urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
            results = []
            for url in urls:
                # Best-effort batch: a failing URL is reported and skipped.
                try:
                    im = load_img(url, output_type="pil").convert("RGB")
                    results.append(_save_result(im))
                except Exception as e:
                    print(f"Error with {url}: {e}")
            return results or None

        return None
    except Exception as e:
        import traceback

        print("General error:", e)
        traceback.print_exc()
        return None
def _build_demo() -> gr.Interface:
    """Assemble the Gradio UI around `handler`.

    The three input components map positionally onto
    handler(image, image_url, batch_urls); the output is one or more files.
    """
    upload = gr.Image(label="Upload Image", type="pil")
    single_url = gr.Textbox(label="Paste Image URL")
    many_urls = gr.Textbox(label="Comma-separated Image URLs (Batch)")
    return gr.Interface(
        fn=handler,
        inputs=[upload, single_url, many_urls],
        outputs=gr.File(label="Output File(s)", file_count="multiple"),
        title="Background Remover (RMBG v1.4)",
        description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
    )


demo = _build_demo()

if __name__ == "__main__":
    # show_error=True reports handler errors in the UI; mcp_server=True asks
    # Gradio to also serve the app's functions over MCP.
    demo.launch(show_error=True, mcp_server=True)
|