anodev committed on
Commit
24e5a6a
·
verified ·
1 Parent(s): 93e825f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -126
app.py CHANGED
@@ -1,138 +1,100 @@
1
  import os
2
- os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx")
3
- os.system("pip install onnxruntime imageio")
4
- import cv2
5
- import paddlehub as hub
6
- import gradio as gr
7
- import torch
8
- from PIL import Image, ImageOps
9
- import numpy as np
10
  import imageio
11
- os.mkdir("data")
12
- os.mkdir("dataout")
13
- model = hub.Module(name='U2Net')
14
  import cv2
15
- import numpy as np
16
  import onnxruntime
17
- import torch
18
- from PIL import Image
19
- sess_options = onnxruntime.SessionOptions()
20
- rmodel = onnxruntime.InferenceSession('lama_fp32.onnx', sess_options=sess_options)
21
-
22
- # Source https://github.com/advimman/lama
23
- def get_image(image):
24
- if isinstance(image, Image.Image):
25
- img = np.array(image)
26
- elif isinstance(image, np.ndarray):
27
- img = image.copy()
28
- else:
29
- raise Exception("Input image should be either PIL Image or numpy array!")
30
-
31
- if img.ndim == 3:
32
- img = np.transpose(img, (2, 0, 1)) # chw
33
- elif img.ndim == 2:
34
- img = img[np.newaxis, ...]
35
-
36
- assert img.ndim == 3
37
-
38
- img = img.astype(np.float32) / 255
39
- return img
40
 
 
 
 
 
 
41
 
42
- def ceil_modulo(x, mod):
43
- if x % mod == 0:
44
- return x
45
- return (x // mod + 1) * mod
46
 
 
 
47
 
48
- def scale_image(img, factor, interpolation=cv2.INTER_AREA):
49
- if img.shape[0] == 1:
50
- img = img[0]
51
- else:
52
- img = np.transpose(img, (1, 2, 0))
53
 
54
- img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
55
-
56
- if img.ndim == 2:
57
- img = img[None, ...]
 
 
 
58
  else:
59
- img = np.transpose(img, (2, 0, 1))
60
- return img
61
-
62
-
63
- def pad_img_to_modulo(img, mod):
64
- channels, height, width = img.shape
65
- out_height = ceil_modulo(height, mod)
66
- out_width = ceil_modulo(width, mod)
67
- return np.pad(
68
- img,
69
- ((0, 0), (0, out_height - height), (0, out_width - width)),
70
- mode="symmetric",
71
- )
72
-
73
-
74
- def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
75
- out_image = get_image(image)
76
- out_mask = get_image(mask)
77
-
78
- if scale_factor is not None:
79
- out_image = scale_image(out_image, scale_factor)
80
- out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
81
-
82
- if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
83
- out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
84
- out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
85
-
86
- out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
87
- out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
88
-
89
- out_mask = (out_mask > 0) * 1
90
-
91
- return out_image, out_mask
92
-
93
-
94
- def predict(jpg, msk):
95
-
96
-
97
- imagex = Image.open(jpg)
98
- mask = Image.open(msk).convert("L")
99
-
100
- image, mask = prepare_img_and_mask(imagex.resize((512, 512)), mask.resize((512, 512)), 'cpu')
101
- # Run the model
102
- outputs = rmodel.run(None, {'image': image.numpy().astype(np.float32), 'mask': mask.numpy().astype(np.float32)})
103
-
104
  output = outputs[0][0]
105
- # Postprocess the outputs
106
  output = output.transpose(1, 2, 0)
107
- output = output.astype(np.uint8)
108
- output = Image.fromarray(output)
109
- output = output.resize(imagex.size)
110
- output.save("/home/user/app/dataout/data_mask.png")
111
-
112
-
113
- def infer(img,option):
114
- print(type(img))
115
- print(type(img["image"]))
116
- print(type(img["mask"]))
117
- imageio.imwrite("./data/data.png", img["image"])
118
- if option == "automatic (U2net)":
119
- result = model.Segmentation(
120
- images=[cv2.cvtColor(img["image"], cv2.COLOR_RGB2BGR)],
121
- paths=None,
122
- batch_size=1,
123
- input_size=320,
124
- output_dir='output',
125
- visualization=True)
126
- im = Image.fromarray(result[0]['mask'])
127
- im.save("./data/data_mask.png")
128
- else:
129
- imageio.imwrite("./data/data_mask.png", img["mask"])
130
- predict("./data/data.png", "./data/data_mask.png")
131
- return "./dataout/data_mask.png","./data/data_mask.png"
132
-
133
- inputs = [gr.Image(label="Input",type="numpy"),gr.inputs.Radio(choices=["automatic (U2net)","manual"], type="value", default="manual", label="Masking option")]
134
- outputs = [gr.outputs.Image(type="file",label="output"),gr.outputs.Image(type="file",label="Mask")]
135
- title = "LaMa Image Inpainting (using ONNX model from Carve))"
136
- description = "Gradio demo for LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Masks are generated by U^2net"
137
- article = "<p style='text-align: center'><a href='https://huggingface.co/Carve/LaMa-ONNX' target='_blank'>ONNX model ported by Carve.Photos</a> | <a href='https://github.com/saic-mdal/lama' target='_blank'>LaMa github repo</a></p>"
138
- gr.Interface(infer, inputs, outputs, title=title, description=description, article=article).launch()
 
 
1
import os

import cv2
import gradio as gr
import imageio
import numpy as np
import onnxruntime
import paddlehub as hub
from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# --- One-time setup: model download, working dirs, model sessions ---

# Fetch the LaMa ONNX weights only when absent: the previous unconditional
# wget re-downloaded the full model on every app restart.
if not os.path.exists("lama_fp32.onnx"):
    os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx")

# NOTE(review): this pip install runs *after* onnxruntime/imageio are imported
# at the top of the file, so it cannot take effect for this process — confirm
# these belong in requirements.txt instead.
os.system("pip install onnxruntime imageio")

os.makedirs("data", exist_ok=True)
os.makedirs("dataout", exist_ok=True)

# Load the LaMa inpainting model as an ONNX Runtime session.
sess_options = onnxruntime.SessionOptions()
lama_model = onnxruntime.InferenceSession('lama_fp32.onnx', sess_options=sess_options)

# Load U^2-Net (via PaddleHub) for automatic mask generation.
u2net_model = hub.Module(name='U2Net')
21
 
22
+ # --- Helper Functions ---
 
 
 
 
23
 
24
def prepare_image(image, target_size=(512, 512)):
    """Resize *image* and convert it to a float32 NCHW batch for LaMa.

    Args:
        image: PIL Image or numpy uint8 array (HxW or HxWxC).
        target_size: (width, height) the model expects; defaults to 512x512.

    Returns:
        numpy float32 array of shape (1, C, H, W) with values in [0, 1].

    Raises:
        ValueError: if *image* is neither a PIL Image nor a numpy array.
    """
    if isinstance(image, Image.Image):
        image = np.array(image.resize(target_size))
    elif isinstance(image, np.ndarray):
        # cv2.resize takes (width, height), matching target_size's order.
        image = cv2.resize(image, target_size)
    else:
        raise ValueError("Input image should be either PIL Image or numpy array!")

    # Drop an alpha channel if present: the inpainting model expects 3
    # channels, and gradio image inputs can arrive as RGBA.
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]

    # Normalize to [0, 1] and convert HWC -> CHW (grayscale gains a channel axis).
    image = image.astype(np.float32) / 255.0
    if image.ndim == 3:
        image = np.transpose(image, (2, 0, 1))
    elif image.ndim == 2:
        image = image[np.newaxis, ...]
    return image[np.newaxis, ...]  # add batch dimension -> NCHW
41
+
42
def generate_mask(image, method="automatic"):
    """Produce an inpainting mask for *image* as a float32 NCHW batch.

    With method == "automatic", the foreground is segmented with U^2-Net
    and saved to ./data/data_mask.png. Any other value falls back to
    reading a user-supplied mask from that same path.

    NOTE(review): the manual path assumes ./data/data_mask.png already
    exists on disk — verify the caller actually writes it in manual mode.
    """
    if method == "automatic":
        seg_results = u2net_model.Segmentation(
            images=[cv2.cvtColor(image, cv2.COLOR_RGB2BGR)],
            paths=None,
            batch_size=1,
            input_size=320,  # U^2-Net's expected input resolution
            output_dir='output',
            visualization=False,
        )
        mask_img = Image.fromarray(seg_results[0]['mask']).resize((512, 512))
        mask_img.save("./data/data_mask.png")
        mask = mask_img
    else:  # "manual"
        raw = imageio.imread("./data/data_mask.png")
        mask = Image.fromarray(raw).convert("L").resize((512, 512))
    return prepare_image(mask, (512, 512))
62
+
63
def inpaint_image(image, mask):
    """Run the LaMa ONNX model and return the inpainted result as a PIL Image.

    Args:
        image: float32 numpy array, shape (1, 3, H, W), values in [0, 1].
        mask: float32 numpy array, shape (1, 1, H, W); nonzero marks the
            region to inpaint.

    Returns:
        PIL.Image of the inpainted output.
    """
    outputs = lama_model.run(None, {'image': image, 'mask': mask})
    output = outputs[0][0]              # first output tensor, first batch item
    output = output.transpose(1, 2, 0)  # CHW -> HWC
    # The previous working code cast to uint8 without scaling, suggesting the
    # Carve LaMa export already emits pixels in 0..255; an unconditional
    # `* 255` would wrap around in uint8. Scale only when the output is
    # actually normalized, and clip to the valid byte range either way.
    # TODO(review): confirm the ONNX model's output range against its card.
    if output.max() <= 1.0:
        output = output * 255.0
    output = np.clip(output, 0, 255).astype(np.uint8)
    return Image.fromarray(output)
70
+
71
+ # --- Gradio Interface ---
72
+
73
def process_image(input_image, mask_option):
    """Gradio entry point: inpaint *input_image* with the chosen masking mode.

    Args:
        input_image: HxWxC uint8 numpy array from the gr.Image input.
        mask_option: "automatic" (U^2-Net segmentation) or "manual".
            NOTE(review): in manual mode generate_mask reads
            ./data/data_mask.png, but nothing in this handler writes it —
            confirm a mask is supplied some other way, or manual mode fails.

    Returns:
        (inpainted image path, mask image path) for the two gradio outputs.
    """
    imageio.imwrite("./data/data.png", input_image)

    image = prepare_image(input_image)
    mask = generate_mask(input_image, method=mask_option)

    inpainted_image = inpaint_image(image, mask)
    # Restore the original resolution directly from the input array instead
    # of re-reading the file we just wrote; PIL resize takes (width, height).
    original_size = (input_image.shape[1], input_image.shape[0])
    inpainted_image = inpainted_image.resize(original_size)
    inpainted_image.save("./dataout/data_mask.png")
    return "./dataout/data_mask.png", "./data/data_mask.png"
84
+
85
# --- Gradio app wiring ---
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image", type="numpy"),
        # Gradio 3+ takes the initial selection via `value=`; the 2.x-era
        # `type="value"` / `default=` kwargs are no longer accepted.
        gr.Radio(choices=["automatic", "manual"], value="manual",
                 label="Masking Option"),
    ],
    outputs=[
        # "filepath" is the gradio 3+ way to return an image by path;
        # type="file" is not a valid gr.Image type.
        gr.Image(type="filepath", label="Inpainted Image"),
        gr.Image(type="filepath", label="Generated Mask"),
    ],
    title="LaMa Image Inpainting",
    description="Image inpainting with LaMa and U^2-Net. Upload your image and choose automatic or manual masking.",
)

iface.launch()