File size: 3,629 Bytes
3fcc660
1e65fde
3fcc660
 
 
 
 
 
d2a1709
3fcc660
 
 
d2a1709
 
 
 
3fcc660
8239fe8
d2a1709
3fcc660
461565d
3fcc660
 
 
d2a1709
3fcc660
 
d8a5e1e
3fcc660
 
1e22887
3fcc660
 
d2a1709
461565d
1a1c89b
 
 
 
 
 
d2a1709
 
461565d
1a1c89b
461565d
3fcc660
461565d
 
 
d2a1709
461565d
d2a1709
3fcc660
d2a1709
3fcc660
 
 
 
d8a5e1e
3fcc660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461565d
 
3fcc660
 
 
 
 
 
4775a16
 
 
3fcc660
 
d2a1709
3fcc660
 
 
 
d2a1709
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import spaces
import torch
import uuid
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Union, List
from loadimg import load_img

# Let PyTorch choose faster float32 matmul kernels at slightly reduced precision.
torch.set_float32_matmul_precision("high")

# RMBG v1.4 segmentation network. trust_remote_code=True is required because
# the model's architecture code is shipped with the checkpoint on the Hub.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-1.4",
    trust_remote_code=True,
)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Preprocessing pipeline applied to every PIL image before inference:
# fixed 1024x1024 resize, conversion to a float tensor, per-channel normalisation.
# NOTE(review): the mean/std values look like the standard ImageNet statistics;
# confirm they match the preprocessing documented for RMBG-1.4.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """Replace the background of *image* with solid white.

    The image is pushed through the module-level ``transform_image`` pipeline
    and ``model``; the predicted mask is resized back to the input size,
    binarised at the midpoint (values above 127 become foreground) and used to
    composite the original pixels over a white canvas.

    Args:
        image: Input picture. Any PIL mode is accepted; it is converted to
            RGB before inference.

    Returns:
        A new RGB image of the same size as the input, with every pixel the
        mask classifies as background painted white.

    Raises:
        ValueError: If the model returns an empty list/tuple of outputs.
    """
    # The normalisation step is configured for three channels and the white
    # canvas below is RGB, so palette/greyscale/RGBA inputs must be converted.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = model(input_tensor)

        # The model may hand back a bare tensor or a sequence of outputs.
        # Lists use their last element, tuples their first.
        if isinstance(preds, (list, tuple)):
            if not preds:
                raise ValueError("Model returned an empty output sequence")
            pred = preds[-1] if isinstance(preds, list) else preds[0]
        else:
            pred = preds

        # NOTE(review): if the model's forward() already returns probabilities,
        # this applies sigmoid a second time -- confirm against the model code.
        mask = pred.sigmoid().cpu()

    # Drop the batch dimension and any singleton channel dimension.
    # NOTE(review): assumes the prediction is (batch, 1, H, W) -- TODO confirm.
    mask_tensor = mask[0].squeeze()
    mask_pil = transforms.ToPILImage()(mask_tensor).resize(image_size).convert("L")

    # Hard threshold: anything above the midpoint counts as foreground.
    binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)

    # Keep foreground pixels from the original, fill the rest with white.
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)

def _save_processed(im: Image.Image) -> str:
    """Run background removal on *im* and save the result as a new PNG.

    The image is converted to RGB, passed to ``process`` and written to the
    current working directory under a random ``output_<8 hex chars>.png`` name.

    Returns:
        The filename the result was written to.
    """
    processed = process(im.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


@spaces.GPU
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Gradio entry point: remove the background from one or several images.

    Exactly one input source is used, checked in this order:

    1. ``image``      -- an uploaded PIL image.
    2. ``image_url``  -- a single URL (surrounding whitespace is ignored).
    3. ``batch_urls`` -- a comma-separated string of URLs; each URL is
       processed independently and failures are logged and skipped.

    Args:
        image: Uploaded PIL image, or ``None``.
        image_url: Single image URL, or empty/``None``.
        batch_urls: Comma-separated image URLs, or empty/``None``.

    Returns:
        The output filename for a single image, a list of output filenames
        for a batch, or ``None`` when no input was given, every batch item
        failed, or an error occurred. Errors are printed (with traceback)
        rather than raised, so the UI receives ``None`` instead of crashing.

    NOTE(review): both this function and ``process`` are wrapped in
    ``spaces.GPU``; confirm that nesting the decorator is intended.
    """
    # A pasted URL often carries stray whitespace; a whitespace-only value
    # should count as "not provided" so the batch field is still considered.
    if isinstance(image_url, str):
        image_url = image_url.strip()

    try:
        if image is not None:
            return _save_processed(image)

        if image_url:
            return _save_processed(load_img(image_url, output_type="pil"))

        if batch_urls:
            results = []
            urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
            for url in urls:
                # Best effort per item: one bad URL must not abort the batch.
                try:
                    results.append(
                        _save_processed(load_img(url, output_type="pil"))
                    )
                except Exception as e:
                    print(f"Error with {url}: {e}")
            return results if results else None

    except Exception as e:
        # Top-level boundary: log everything and fall through to None.
        print("General error:", e)
        import traceback
        traceback.print_exc()

    return None

# Gradio UI wiring. The three inputs map positionally onto
# handler(image, image_url, batch_urls), so their order must not change.
demo = gr.Interface(
    title="Background Remover (RMBG v1.4)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
    fn=handler,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Paste Image URL"),
        gr.Textbox(label="Comma-separated Image URLs (Batch)"),
    ],
    # file_count="multiple" lets the same component show one file or a batch.
    outputs=gr.File(file_count="multiple", label="Output File(s)"),
)

# Start the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch(mcp_server=True, show_error=True)