File size: 3,243 Bytes
003d203
 
 
0e6c023
 
 
aa16383
dcfda89
0e6c023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
003d203
aa16383
 
 
 
 
 
 
dcfda89
 
 
 
aa16383
dcfda89
 
 
 
 
 
aa16383
dcfda89
5fb3d3c
 
 
 
 
aa16383
 
 
 
 
dcfda89
 
 
 
 
aa16383
 
dcfda89
aa16383
6c5427b
e406805
 
aa16383
 
dcfda89
 
 
 
aa16383
dcfda89
aa16383
 
 
 
e406805
 
 
aa16383
 
 
 
e7bef73
 
0e6c023
 
 
 
 
 
 
 
 
 
 
 
dcfda89
e406805
 
991bda2
e406805
 
aa16383
 
 
dcfda89
aa16383
 
 
dcfda89
 
 
 
 
 
 
 
 
 
 
aa16383
0e6c023
 
aa16383
 
94386db
0e6c023
 
003d203
0e6c023
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps

# Let PyTorch use faster reduced-precision internal math (e.g. TF32) for
# float32 matrix multiplications; "high" trades a little precision for speed.
torch.set_float32_matmul_precision("high")

# Background-removal model. Its modelling code is shipped with the checkpoint
# on the Hub, which is why trust_remote_code=True is required.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# Preprocessing for BiRefNet input: fixed 1024x1024 resize, conversion to a
# float tensor, and normalisation with the standard ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# FLUX.1 Fill pipeline used for outpainting, loaded in bfloat16 on the GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")


def _as_padding(value):
    """Coerce a padding value from the UI/API into a non-negative int (None -> 0)."""
    if value is None:
        return 0
    # Image sizes must be integers, and a negative border would crop rather
    # than pad, so clamp to zero.
    return max(0, int(value))


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad an image with white borders and build the matching outpainting mask.

    Args:
        image: Any source accepted by ``load_img`` (path, URL, array, PIL image).
        padding_top, padding_bottom, padding_left, padding_right: Border widths
            in pixels; coerced to non-negative ints (``None`` counts as 0).

    Returns:
        ``(background, mask)``: the RGB image expanded with white borders, and
        an RGB mask of the same size that is black over the original pixels and
        white over the added padding (the region to generate).
    """
    image = load_img(image).convert("RGB")
    # ImageOps.expand takes the border as (left, top, right, bottom).
    border = (
        _as_padding(padding_left),
        _as_padding(padding_top),
        _as_padding(padding_right),
        _as_padding(padding_bottom),
    )
    background = ImageOps.expand(image, border=border, fill="white")
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"),
        border=border,
        fill="white",
    )
    return background, mask


def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
):
    """Outpaint an image by generating content for the added padding with FLUX Fill.

    Args:
        image: Any source accepted by ``load_img``.
        padding_top, padding_bottom, padding_left, padding_right: Pixels to add
            on each side.
        prompt: Text prompt guiding the generated content.

    Returns:
        An RGB PIL image the size of the padded canvas. The original pixels are
        kept as-is and only the padding is taken from the pipeline output.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    mask = mask.convert("L")

    # An all-black mask means no padding was requested: there is nothing to
    # generate, so skip the expensive diffusion run. (The original pixels would
    # have been kept everywhere anyway.)
    if mask.getbbox() is None:
        return background.copy()

    cnet_image = background.copy()
    # Blank out the region to be generated.
    cnet_image.paste(0, (0, 0), mask)

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=28,
        guidance_scale=30,
    ).images[0]

    # diffusers' Flux pipelines work on a 16-px grid and round height/width
    # that are not multiples of 16 down. The output can then be smaller than
    # the canvas, and paste() with a mask requires matching sizes, so bring
    # the result back to the canvas size first.
    if result.size != background.size:
        result = result.resize(background.size, Image.LANCZOS)

    # Copy generated pixels only where the mask is white (the padding).
    cnet_image.paste(result.convert("RGBA"), (0, 0), mask)

    return cnet_image


def rmbg(image, url):
    """Remove the background of an image using BiRefNet.

    Args:
        image: Uploaded image (any source ``load_img`` accepts), or ``None``.
        url: Fallback source (text box value, e.g. an image URL) used when
            ``image`` is ``None``; surrounding whitespace is ignored.

    Returns:
        The input converted to RGB, with BiRefNet's predicted foreground map
        attached as the alpha channel (so the result is RGBA).

    Raises:
        gr.Error: If neither an image nor a non-empty URL was provided.
    """
    if image is None:
        image = url.strip() if isinstance(url, str) else url
    # An empty text box arrives as "", which load_img cannot resolve; report
    # the problem to the user instead of failing deep inside the loader.
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Provide an image or an image URL.")

    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # The last model output is taken as the final prediction; sigmoid maps
        # it to a [0, 1] foreground map (treating it as logits).
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Scale the 1024x1024 prediction back to the original resolution.
    mask = transforms.ToPILImage()(pred).resize(image_size)
    image.putalpha(mask)
    return image


@spaces.GPU
def main(*args, progress=gr.Progress(track_tqdm=True)):
    """Single GPU entry point shared by both tabs, dispatched on argument count.

    Two positional arguments (image, url) come from the "remove background"
    tab and go to ``rmbg``. Any other count is forwarded to ``inpaint``
    (the outpainting tab sends image, four paddings and a prompt).

    ``progress`` is not read here. Declaring it with ``track_tqdm=True`` asks
    Gradio to track tqdm progress bars during the call.
    """
    if len(args) != 2:
        return inpaint(*args)
    return rmbg(*args)


# "remove background" tab: an image plus a text box (image URL fallback),
# which reaches main() as a two-argument call.
rmbg_tab = gr.Interface(
    fn=main,
    inputs=["image", "text"],
    outputs=["image"],
    api_name="rmbg",
)

# Slider labels in the positional order inpaint() expects its paddings.
_PADDING_LABELS = ("padding top", "padding bottom", "padding left", "padding right")

# "outpainting" tab inputs: image, four padding sliders, then the prompt,
# i.e. a six-argument call to main().
_outpaint_inputs = ["image"]
_outpaint_inputs.extend(gr.Slider(label=side_label) for side_label in _PADDING_LABELS)
_outpaint_inputs.append(gr.Text(label="prompt"))

outpaint_tab = gr.Interface(
    fn=main,
    inputs=_outpaint_inputs,
    outputs=["image"],
    api_name="outpainting",
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)


demo.launch()