"""Gradio app bundling several GPU-backed image utilities.

Tabs exposed by the app:

* remove background -- BiRefNet segmentation used as an alpha channel
* outpainting       -- FLUX.1 Fill on a padded canvas
* inpainting        -- FLUX.1 Fill with a caller-supplied mask
* sam2              -- placeholder, ``mask_generation`` currently returns ``None``

Every tab routes through ``main``, whose first argument selects the utility.
"""

# NOTE(review): import order is kept exactly as in the original file.
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
from sam2.sam2_image_predictor import SAM2ImagePredictor

torch.set_float32_matmul_precision("high")

# Background-removal model, moved to the GPU at import time.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# SAM2 predictor is loaded on CPU; it is not used by any function below yet.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny", device="cpu")

# Preprocessing applied to images before they are fed to ``birefnet``.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Fill pipeline shared by ``outpaint`` and ``inpaint``.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")


def _as_int(value, default=0):
    """Coerce a UI-supplied number to ``int``; ``None`` becomes *default*.

    ``gr.Number`` fields may hand over floats (or ``None`` when left empty),
    while image sizes and step counts must be integers.
    """
    if value is None:
        return default
    return int(value)


def _run_fill(background, mask, prompt, num_inference_steps, guidance_scale):
    """Run the fill pipeline on *background* / *mask* and return an RGBA image."""
    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        # Gradio may deliver this as a float; the step count must be an int.
        num_inference_steps=_as_int(num_inference_steps, 28),
        guidance_scale=guidance_scale,
    ).images[0]
    return result.convert("RGBA")


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad *image* with white borders and build the matching outpainting mask.

    Args:
        image: Anything accepted by ``load_img``.
        padding_top: Pixels to add above the image.
        padding_bottom: Pixels to add below the image.
        padding_left: Pixels to add left of the image.
        padding_right: Pixels to add right of the image.

    Returns:
        A ``(background, mask)`` tuple of equally sized PIL images. The mask is
        black over the original image area and white over the added border.
    """
    image = load_img(image).convert("RGB")
    # Paddings are truncated to ints (``None`` -> 0) so that the expanded
    # canvas always has integer dimensions.
    border = (
        _as_int(padding_left),
        _as_int(padding_top),
        _as_int(padding_right),
        _as_int(padding_bottom),
    )
    # expand image (left,top,right,bottom)
    background = ImageOps.expand(image, border=border, fill="white")
    mask = Image.new("RGB", image.size, "black")
    mask = ImageOps.expand(mask, border=border, fill="white")
    return background, mask


def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Extend *image* by the given paddings and fill the new area.

    NOTE(review): the defaults here (28 steps, guidance 50) are the reverse of
    the values preset in the UI (50 steps, guidance 28) -- confirm which pair
    is intended.

    Returns:
        The generated image converted to RGBA.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    return _run_fill(background, mask, prompt, num_inference_steps, guidance_scale)


def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Regenerate the masked region of *image*.

    Args:
        image: PIL image; converted to RGB before use.
        mask: PIL image; converted to single-channel ``"L"`` before use.
            Presumably white marks the area to repaint, matching the mask
            built by ``prepare_image_and_mask`` -- confirm.
        prompt: Text prompt forwarded to the pipeline.
        num_inference_steps: Number of pipeline steps.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The generated image converted to RGBA.
    """
    background = image.convert("RGB")
    mask = mask.convert("L")
    return _run_fill(background, mask, prompt, num_inference_steps, guidance_scale)


def rmbg(image=None, url=None):
    """Remove the background of an image given directly or by URL.

    Args:
        image: Image input; takes precedence over *url* when not ``None``.
        url: Fallback source used when *image* is ``None``.

    Returns:
        The RGB image with the predicted mask attached as its alpha channel.

    Raises:
        gr.Error: If neither an image nor a non-empty URL was supplied.
    """
    if image is None:
        image = url
    # The URL textbox defaults to "", so guard against "nothing supplied".
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("pass an image or a url of an image")
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # The model ran on a 1024x1024 resize; scale the mask back to the input size.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


def mask_generation(image=None, json=None):
    """Placeholder for SAM2 mask generation; currently always returns ``None``."""
    return None


# Maps the selector passed as the first argument of ``main`` to its handler.
_API_HANDLERS = {
    1: rmbg,
    2: outpaint,
    3: inpaint,
    4: mask_generation,
}


@spaces.GPU
def main(*args):
    """Dispatch to one of the utilities based on the first argument.

    Args:
        *args: ``args[0]`` selects the utility (1 = rmbg, 2 = outpaint,
            3 = inpaint, 4 = mask_generation); the remaining arguments are
            forwarded to it unchanged.

    Returns:
        The selected utility's result, or ``None`` for an unknown selector.
    """
    api_num = args[0]
    handler = _API_HANDLERS.get(api_num)
    if handler is None:
        return None
    return handler(*args[1:])


rmbg_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(1, interactive=False),
        "image",
        gr.Text("", label="url"),
    ],
    outputs=["image"],
    api_name="rmbg",
    examples=[[1, "./assets/Inpainting mask.png", ""]],
    cache_examples=False,
    description="pass an image or a url of an image",
)

outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Number(label="padding top"),
        gr.Number(label="padding bottom"),
        gr.Number(label="padding left"),
        gr.Number(label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
    cache_examples=False,
)

inpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(3, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Image(label="mask", type="pil"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="inpaint",
    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png"]],
    cache_examples=False,
    description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
)

sam2_tab = gr.Interface(
    main,
    inputs=[
        gr.Number(4, interactive=False),
        # ``label=`` rather than a positional string: the first positional
        # parameter of ``gr.Image`` is ``value``, not the label.
        gr.Image(label="image", type="pil"),
        gr.JSON(),
    ],
    outputs=["image"],
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab, inpaint_tab, sam2_tab],
    ["remove background", "outpainting", "inpainting", "sam2"],
    title="Utilities that require GPU",
)

demo.launch()