import os

# Fetch the LaMa ONNX weights and runtime dependencies at startup.
os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx")
os.system("pip install onnxruntime imageio")

import cv2
import gradio as gr
import imageio
import numpy as np
import onnxruntime
import paddlehub as hub
import torch
from PIL import Image

# Working directories for the input image/mask and the inpainted result.
os.makedirs("data", exist_ok=True)
os.makedirs("dataout", exist_ok=True)

# U^2-Net produces the automatic mask; LaMa (ONNX) does the inpainting.
model = hub.Module(name='U2Net')
sess_options = onnxruntime.SessionOptions()
rmodel = onnxruntime.InferenceSession('lama_fp32.onnx', sess_options=sess_options)

# Pre/post-processing helpers adapted from https://github.com/advimman/lama
def get_image(image):
    if isinstance(image, Image.Image):
        img = np.array(image)
    elif isinstance(image, np.ndarray):
        img = image.copy()
    else:
        raise TypeError("Input image should be either a PIL Image or a numpy array!")

    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    elif img.ndim == 2:
        img = img[np.newaxis, ...]  # add a channel axis for grayscale input

    assert img.ndim == 3

    img = img.astype(np.float32) / 255
    return img
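
# For reference: get_image turns an RGB PIL image of size (W, H) into a
# float32 array of shape (3, H, W) scaled to [0, 1]; a grayscale mask
# becomes (1, H, W).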


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod
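
# For example, ceil_modulo(500, 8) == 504: dimensions are rounded up to the
# next multiple of `mod` so the padded image divides evenly.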


def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    if img.shape[0] == 1:
        img = img[0]
    else:
        img = np.transpose(img, (1, 2, 0))

    img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

    if img.ndim == 2:
        img = img[None, ...]
    else:
        img = np.transpose(img, (2, 0, 1))
    return img
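
# Note: scale_image preserves the CHW layout around cv2.resize (which expects
# HWC or HW), and prepare_img_and_mask passes INTER_NEAREST for masks so the
# resized mask stays binary instead of being smoothed by interpolation.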


def pad_img_to_modulo(img, mod):
    channels, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(
        img,
        ((0, 0), (0, out_height - height), (0, out_width - width)),
        mode="symmetric",
    )
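
# A (3, 500, 500) image padded with mod=8 becomes (3, 504, 504); the extra
# rows/columns mirror the border ("symmetric" mode) rather than being
# zero-filled, which avoids hard edges at the padding seam.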


def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
    out_image = get_image(image)
    out_mask = get_image(mask)

    if scale_factor is not None:
        out_image = scale_image(out_image, scale_factor)
        out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)

    if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
        out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
        out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)

    out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
    out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)

    # Binarize the mask: any non-zero pixel is treated as "inpaint here".
    out_mask = (out_mask > 0) * 1

    return out_image, out_mask
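
# A minimal usage sketch (hypothetical, not part of the demo flow): both
# tensors come back batched as (1, C, H, W) on the chosen device.
#   img_t, mask_t = prepare_img_and_mask(
#       Image.new("RGB", (512, 512)), Image.new("L", (512, 512)), 'cpu')
#   assert img_t.shape == (1, 3, 512, 512)
#   assert mask_t.shape == (1, 1, 512, 512)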


def predict(jpg, msk):
    imagex = Image.open(jpg).convert("RGB")
    mask = Image.open(msk).convert("L")

    # The demo runs inference at 512x512 and rescales the result afterwards.
    image, mask = prepare_img_and_mask(imagex.resize((512, 512)), mask.resize((512, 512)), 'cpu')
    # Run the model
    outputs = rmodel.run(None, {'image': image.numpy().astype(np.float32), 'mask': mask.numpy().astype(np.float32)})

    # Postprocess the output: CHW float in [0, 255] -> HWC uint8 image.
    output = outputs[0][0].transpose(1, 2, 0)
    output = np.clip(output, 0, 255).astype(np.uint8)
    output = Image.fromarray(output)
    output = output.resize(imagex.size)
    output.save("./dataout/data_mask.png")


def infer(img, option):
    # With the sketch tool, `img` arrives as a dict holding the uploaded
    # image and the user-drawn mask.
    imageio.imwrite("./data/data.png", img["image"])
    if option == "automatic (U2net)":
        # Let U^2-Net segment the salient object and use it as the inpainting mask.
        result = model.Segmentation(
            images=[cv2.cvtColor(img["image"], cv2.COLOR_RGB2BGR)],
            paths=None,
            batch_size=1,
            input_size=320,
            output_dir='output',
            visualization=True)
        im = Image.fromarray(result[0]['mask'])
        im.save("./data/data_mask.png")
    else:
        imageio.imwrite("./data/data_mask.png", img["mask"])
    predict("./data/data.png", "./data/data_mask.png")
    return "./dataout/data_mask.png", "./data/data_mask.png"

inputs = [
    gr.Image(label="Input", type="numpy", tool="sketch"),
    gr.Radio(choices=["automatic (U2net)", "manual"], value="manual", label="Masking option"),
]
outputs = [
    gr.Image(type="filepath", label="Output"),
    gr.Image(type="filepath", label="Mask"),
]
title = "LaMa Image Inpainting (using ONNX model from Carve)"
description = "Gradio demo for LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions. Upload an image, then either draw a mask manually or let the automatic option generate one. Read more at the links below. Automatic masks are generated by U^2-Net."
article = "<p style='text-align: center'><a href='https://huggingface.co/Carve/LaMa-ONNX' target='_blank'>ONNX model ported by Carve.Photos</a> | <a href='https://github.com/saic-mdal/lama' target='_blank'>LaMa github repo</a></p>"
gr.Interface(infer, inputs, outputs, title=title, description=description, article=article).launch()
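
# To exercise the inpainting path without the UI (a hypothetical smoke test,
# assuming an image and mask already exist under ./data):
#   predict("./data/data.png", "./data/data_mask.png")
#   # -> inpainted result written to ./dataout/data_mask.png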