# NOTE(review): the commented-out block below is the previous version of this
# script (identical to the live code except that it lacks base64 data-URL
# support). It is kept only for reference and can be deleted once it is no
# longer needed.
#
# import gradio as gr
# import torch
# import uuid
# from PIL import Image
# from torchvision import transforms
# from transformers import AutoModelForImageSegmentation
# from typing import Union, List
# from loadimg import load_img  # Your helper to load from URL or file
#
# torch.set_float32_matmul_precision("high")
#
# # Load BiRefNet model
# birefnet = AutoModelForImageSegmentation.from_pretrained(
#     "ZhengPeng7/BiRefNet", trust_remote_code=True
# )
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# birefnet.to(device)
#
# # Image transformation
# transform_image = transforms.Compose([
#     transforms.Resize((1024, 1024)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])
#
# def process(image: Image.Image) -> Image.Image:
#     image_size = image.size
#     input_tensor = transform_image(image).unsqueeze(0).to(device)
#     with torch.no_grad():
#         preds = birefnet(input_tensor)[-1].sigmoid().cpu()
#     pred = preds[0].squeeze()
#     mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
#     binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
#     white_bg = Image.new("RGB", image_size, (255, 255, 255))
#     result = Image.composite(image, white_bg, binary_mask)
#     return result
#
# def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
#     results = []
#     try:
#         # Single image upload
#         if image is not None:
#             image = image.convert("RGB")
#             processed = process(image)
#             filename = f"output_{uuid.uuid4().hex[:8]}.png"
#             processed.save(filename)
#             return filename
#         # Single image from URL
#         if image_url:
#             im = load_img(image_url, output_type="pil").convert("RGB")
#             processed = process(im)
#             filename = f"output_{uuid.uuid4().hex[:8]}.png"
#             processed.save(filename)
#             return filename
#         # Batch of URLs
#         if batch_urls:
#             urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
#             for url in urls:
#                 try:
#                     im = load_img(url, output_type="pil").convert("RGB")
#                     processed = process(im)
#                     filename = f"output_{uuid.uuid4().hex[:8]}.png"
#                     processed.save(filename)
#                     results.append(filename)
#                 except Exception as e:
#                     print(f"Error with {url}: {e}")
#             return results if results else None
#     except Exception as e:
#         print("General error:", e)
#         return None
#
# # Interface
# demo = gr.Interface(
#     fn=handler,
#     inputs=[
#         gr.Image(label="Upload Image", type="pil"),
#         gr.Textbox(label="Paste Image URL"),
#         gr.Textbox(label="Comma-separated Image URLs (Batch)"),
#     ],
#     outputs=gr.File(label="Output File(s)", file_count="multiple"),
#     title="Background Remover (White Fill)",
#     description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
# )
#
# if __name__ == "__main__":
#     demo.launch(show_error=True, mcp_server=True)

import base64
import uuid
from io import BytesIO
from typing import List, Union

import gradio as gr
import torch
from loadimg import load_img  # Your helper to load from URL or file
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

torch.set_float32_matmul_precision("high")

# Load BiRefNet model
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
birefnet.to(device)

# Image transformation
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Prefix that marks an inline (base64) image rather than a fetchable URL.
_DATA_URL_PREFIX = "data:image/"

# Longest URL fragment echoed in error messages; data URLs can be megabytes.
_MAX_LOGGED_URL_CHARS = 80


def load_image_from_data_url(data_url: str) -> Image.Image:
    """Load an image from a base64 ``data:image/...`` URL or a regular URL.

    Args:
        data_url: Either a data URL of the form ``data:image/<type>;base64,<payload>``
            or anything ``load_img`` accepts (regular URL or file path).

    Returns:
        The decoded PIL image (mode is whatever the source image uses).

    Raises:
        ValueError: If ``data_url`` starts with ``data:image/`` but has no
            comma separating the header from the payload. ``base64.b64decode``
            may also raise ``binascii.Error`` (a ``ValueError`` subclass) for
            a malformed payload.
    """
    if not data_url.startswith(_DATA_URL_PREFIX):
        # Regular URL, use existing load_img function
        return load_img(data_url, output_type="pil")

    # Everything after the first comma is the base64 payload.
    _header, separator, encoded = data_url.partition(",")
    if not separator:
        raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
    return Image.open(BytesIO(base64.b64decode(encoded)))


def process(image: Image.Image) -> Image.Image:
    """Replace the background of ``image`` with solid white.

    The image is resized/normalized by ``transform_image``, segmented with
    BiRefNet, and the resulting mask (thresholded at 127) is used to composite
    the original pixels over a white canvas of the original size.

    Args:
        image: Input PIL image. Non-RGB images are converted to RGB first,
            because the normalization above is defined for three channels and
            ``Image.composite`` needs both images in the same mode.

    Returns:
        A new RGB image with the same size as the input.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The last model output is used as the segmentation logits.
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 127 else 0)

    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)


def _split_batch_urls(batch_urls: str) -> List[str]:
    """Split a comma-separated URL list without breaking base64 data URLs.

    A data URL contains a comma between its header and payload, so a plain
    ``split(",")`` would cut it in two. A piece that starts with
    ``data:image/`` is therefore treated as a header and re-joined with the
    piece that follows it.
    """
    urls: List[str] = []
    pending_header = None
    for piece in batch_urls.split(","):
        piece = piece.strip()
        if pending_header is not None:
            urls.append(f"{pending_header},{piece}")
            pending_header = None
        elif piece.startswith(_DATA_URL_PREFIX):
            pending_header = piece
        elif piece:
            urls.append(piece)
    if pending_header is not None:
        # Header without payload: pass it on so the loader reports the error.
        urls.append(pending_header)
    return urls


def _process_and_save(image: Image.Image) -> str:
    """Run background removal on ``image`` and save it as a uniquely named PNG.

    Returns the name of the file written to the current working directory.
    """
    processed = process(image.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from one image or a batch of images.

    Exactly one input is used, checked in this order: uploaded image, single
    URL, batch of URLs. URLs may be regular URLs or base64 data URLs.

    Args:
        image: Uploaded PIL image, or None.
        image_url: A single image URL or base64 data URL, or None/empty.
        batch_urls: Comma-separated image URLs and/or base64 data URLs, or
            None/empty.

    Returns:
        The output PNG filename for a single image, a list of filenames for a
        batch (failed items are skipped), or None when nothing was processed
        or an error occurred.
    """
    try:
        # Single image upload
        if image is not None:
            return _process_and_save(image)

        # Single image from URL (supports both regular URLs and base64 data URLs)
        single_url = image_url.strip() if image_url else ""
        if single_url:
            return _process_and_save(load_image_from_data_url(single_url))

        # Batch of URLs (supports both regular URLs and base64 data URLs)
        if batch_urls:
            results = []
            for url in _split_batch_urls(batch_urls):
                try:
                    results.append(_process_and_save(load_image_from_data_url(url)))
                except Exception as e:
                    # Best effort: one bad item must not abort the whole batch.
                    shown = url
                    if len(shown) > _MAX_LOGGED_URL_CHARS:
                        shown = f"{shown[:_MAX_LOGGED_URL_CHARS]}..."
                    print(f"Error with {shown}: {e}")
            return results if results else None
    except Exception as e:
        print("General error:", e)
    return None


# Interface
demo = gr.Interface(
    fn=handler,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Paste Image URL"),
        gr.Textbox(label="Comma-separated Image URLs (Batch)"),
    ],
    outputs=gr.File(label="Output File(s)", file_count="multiple"),
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
)

if __name__ == "__main__":
    demo.launch(show_error=True, mcp_server=True)