# gpu-utils / app.py
# Last commit e406805 by not-lain: "refactor inpaint function to return
# composited image and update main function signature" (file size: 3.14 kB)
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
# Allow the faster "high" float32 matmul precision mode (instead of "highest").
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model used by `rmbg`; trust_remote_code=True lets
# transformers execute the model code hosted in the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cuda")

# Preprocessing for BiRefNet: resize to 1024x1024, convert to a tensor and
# normalize with the standard ImageNet channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# FLUX.1 Fill pipeline used by `inpaint` for outpainting, in bfloat16 on the GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad *image* with white borders and build the matching outpainting mask.

    Args:
        image: Any input accepted by ``load_img``.
        padding_top, padding_bottom, padding_left, padding_right: Border width
            in pixels to add on each side. Gradio sliders may deliver floats,
            so each value is converted with ``int()``.

    Returns:
        ``(background, mask)``. ``background`` is the RGB image expanded with
        white borders. ``mask`` is an ``"L"`` image of exactly the same size:
        black (0) over the original pixels and white (255) over the added
        borders (the area to outpaint).
    """
    # PIL needs integer sizes; ImageOps.expand takes (left, top, right, bottom).
    border = tuple(
        int(pad) for pad in (padding_left, padding_top, padding_right, padding_bottom)
    )
    image = load_img(image).convert("RGB")
    background = ImageOps.expand(image, border=border, fill="white")
    # Expand the mask with the SAME border so it lines up pixel-for-pixel with
    # `background`. "L" mode is a valid mask for Image.paste (RGB is not).
    mask = ImageOps.expand(Image.new("L", image.size, 0), border=border, fill=255)
    return background, mask
def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    *,
    num_inference_steps=28,
    guidance_scale=30,
):
    """Outpaint *image* by the given paddings with the FLUX.1 Fill pipeline.

    The padded canvas and its mask come from ``prepare_image_and_mask``. The
    pipeline output is composited back onto the padded original through the
    mask, so only the masked (padded) area takes generated pixels.

    Args:
        image: Input image, anything accepted by ``prepare_image_and_mask``.
        padding_top, padding_bottom, padding_left, padding_right: Pixels to
            add on each side.
        prompt: Text prompt guiding the generated content.
        num_inference_steps: Denoising steps forwarded to ``pipe``.
        guidance_scale: Guidance scale forwarded to ``pipe``.

    Returns:
        A ``PIL.Image`` with the size and mode of the padded canvas.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    # Image.paste only accepts "1"/"L"/"LA"/"RGBA"-style masks; an RGB mask
    # raises "bad transparency mask", so normalize to "L".
    mask = mask.convert("L")
    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    # NOTE: FluxFillPipeline may round height/width to a multiple of 16, so the
    # output can differ from the canvas size; masked paste needs equal sizes.
    if result.size != background.size:
        result = result.resize(background.size)
    composited = background.copy()
    # Keep original pixels where the mask is 0; take generated ones where 255.
    composited.paste(result.convert(background.mode), (0, 0), mask)
    return composited
def rmbg(image, url):
    """Remove the background of an image with BiRefNet.

    Args:
        image: Image from the UI, or ``None`` when nothing was uploaded.
        url: Fallback source passed to ``load_img`` when ``image`` is ``None``.

    Returns:
        The input as a ``PIL.Image`` with an alpha channel set to the
        predicted foreground mask (``putalpha`` on an RGB image yields RGBA).

    Raises:
        gr.Error: If neither an image nor a non-empty URL was provided.
    """
    if image is None:
        if not url:
            # Show a readable message in the UI instead of a load_img failure.
            raise gr.Error("Please provide an image or an image URL.")
        image = url
    image = load_img(image).convert("RGB")
    original_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction: the model returns a sequence of outputs; the last one is used.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # The prediction was made on the 1024x1024-resized input; scale it back.
    mask = transforms.ToPILImage()(pred).resize(original_size)
    image.putalpha(mask)
    return image
@spaces.GPU
def main(*args, progress=gr.Progress(track_tqdm=True)):
    """Shared GPU entry point for both Gradio tabs.

    Both interfaces use this function as ``fn``. It routes on the number of
    positional inputs: exactly two (image, url) goes to ``rmbg``, anything
    else goes to ``inpaint``. ``progress`` is not read here; its
    ``gr.Progress(track_tqdm=True)`` default asks Gradio to track tqdm bars.
    """
    handler = rmbg if len(args) == 2 else inpaint
    return handler(*args)
# Tab 1 -- background removal: two inputs (image, url), routed by main() to rmbg().
rmbg_tab = gr.Interface(
    fn=main,
    inputs=["image", "text"],
    outputs=["image"],
    api_name="rmbg",
)

# Tab 2 -- outpainting. Inputs reach inpaint() positionally, so their order
# must match its signature: image, top, bottom, left, right, prompt.
_PADDING_LABELS = (
    "padding top",
    "padding bottom",
    "padding left",
    "padding right",
)
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        "image",
        *[gr.Slider(label=label) for label in _PADDING_LABELS],
        gr.Text(label="prompt"),
    ],
    outputs=["image"],
    api_name="outpainting",
)

# Combine both tools into one tabbed app and start the server.
demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)
demo.launch()