ai-forever committed on
Commit
b212d2d
·
verified ·
1 Parent(s): 15d76c3

Upload app.py

Files changed (1)
  1. app.py +288 -0
app.py ADDED
@@ -0,0 +1,288 @@
+ import spaces
+ import gradio as gr
+
+ import cv2
+ import torch
+ import argparse
+ import yaml
+ import numpy as np
+ from torchvision import transforms
+ import onnxruntime as ort
+ from PIL import Image
+ from insightface.app import FaceAnalysis
+ from omegaconf import OmegaConf
+ from torchvision.transforms.functional import rgb_to_grayscale
+ from gradio_imageslider import ImageSlider  # output widget; assumed to come from the gradio_imageslider package
+
+ from src.utils.crops import *
+ from repos.stylematte.stylematte.models import StyleMatte
+ from src.utils.inference import *
+ from src.utils.inpainter import LamaInpainter
+ from src.utils.preblending import calc_pseudo_target_bg
+ from train_aligner import AlignerModule
+ from train_blender import BlenderModule
+
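+ # Head-swap inference for a single source/target pair: detect and crop both faces,
+ # run the Aligner to generate the swapped head, blend it into the target crop with
+ # the Blender, and paste the result back into the full target frame.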
+ @spaces.GPU
+ def infer_headswap(source, target):
+     def calc_mask(img):
+         if isinstance(img, np.ndarray):
+             img = torch.from_numpy(img).permute(2, 0, 1).cuda()
+         if img.max() > 1.:
+             img = img / 255.0
+         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                          std=[0.229, 0.224, 0.225])
+         input_t = normalize(img)
+         input_t = input_t.unsqueeze(0).float()
+         with torch.no_grad():
+             out = segment_model(input_t)
+         result = out[0]
+
+         return result[0]
+
+     def process_img(img, target=False):
+         full_frames = np.array(img)[:, :, ::-1]
+         dets = app.get(full_frames)
+         kps = dets[0]['kps']
+         wide = wide_crop_face(full_frames, kps, return_M=target)
+         if target:
+             wide, M = wide
+         arc = norm_crop(full_frames, kps)
+         mask = calc_mask(wide)
+         arc = normalize_and_torch(arc)
+         wide = normalize_and_torch(wide)
+         if target:
+             return wide, arc, mask, full_frames, M
+         return wide, arc, mask
+
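+     # Preprocess both images; the target call also returns the full frame and the
+     # crop transform M so the result can be pasted back at the end.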
+     wide_source, arc_source, mask_source = process_img(source)
+     wide_target, arc_target, mask_target, full_frame, M = process_img(target, target=True)
+
+
+     wide_source = wide_source.unsqueeze(1)
+     arc_source = arc_source.unsqueeze(1)
+     source_mask = mask_source.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+     target_mask = mask_target.unsqueeze(0).unsqueeze(0)
+
+     X_dict = {
+         'source': {
+             'face_arc': arc_source,
+             'face_wide': wide_source * mask_source,
+             'face_wide_mask': mask_source
+         },
+         'target': {
+             'face_arc': arc_target,
+             'face_wide': wide_target * mask_target,
+             'face_wide_mask': mask_target
+         }
+     }
+
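+     # Stage 1: the Aligner generates the swapped head ('fake_rgbs') and its segmentation ('fake_segm').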
+     with torch.no_grad():
+         output = aligner(X_dict)
+
+
+     target_parsing = infer_parsing(wide_target)
+     pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)
+     soft_mask = calc_mask(((output['fake_rgbs'] * output['fake_segm'])[0, [2, 1, 0], :, :] + 1) / 2)[None]
+     new_source = output['fake_rgbs'] * soft_mask[:, None, ...] + pseudo_norm_target * (1 - soft_mask[:, None, ...])
+
+     blender_input = {
+         'face_source': new_source, # output['fake_rgbs']*output['fake_segm'] + norm_target*(1-output['fake_segm']),# face_source,
+         'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
+         'face_target': wide_target,
+         'mask_source': infer_parsing(output['fake_rgbs']*output['fake_segm']),
+         'mask_target': target_parsing,
+         'mask_source_noise': None,
+         'mask_target_noise': None,
+         'alpha_source': soft_mask
+     }
+
+     output_b = blender(blender_input, inpainter=inpainter)
+
+     np_output = np.uint8((output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:,:,::-1] / 2 + 0.5)*255)
+     result = copy_head_back(np_output, full_frame[..., ::-1], M)
+     return Image.fromarray(result)
+
+
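+ # Entry point for the Space / CLI run: load every model once at startup,
+ # then expose infer_headswap through a small Gradio interface.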
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     # Generator params
+     parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
+     parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
+     parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
+     parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
+     parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
+     parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
+     parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
+
+     args = parser.parse_args()
+
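+     # Load Aligner and Blender configs and checkpoints.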
+     with open(args.config_a, "r") as stream:
+         cfg_a = OmegaConf.load(stream)
+
+     with open(args.config_b, "r") as stream:
+         cfg_b = OmegaConf.load(stream)
+
+     aligner = AlignerModule(cfg_a)
+     ckpt = torch.load(args.ckpt_a, map_location='cpu')
+     aligner.load_state_dict(ckpt, strict=False)  # reuse the checkpoint already loaded on CPU above
+     aligner.eval()
+     aligner.cuda()
+
+     blender = BlenderModule(cfg_b)
+     blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')["state_dict"], strict=False)
+     blender.eval()
+     blender.cuda()
+
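+     # Auxiliary models: LaMa-based inpainter (passed to the Blender), InsightFace
+     # face detector, and the StyleMatte matting network used in calc_mask.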
+     inpainter = LamaInpainter()
+
+     app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
+     app.prepare(ctx_id=0, det_size=(640, 640))
+
+     segment_model = StyleMatte()
+     segment_model.load_state_dict(
+         torch.load(
+             './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
+             map_location='cpu'
+         )
+     )
+     segment_model = segment_model.cuda()
+     segment_model.eval()
+
+     providers = [
+         ("CUDAExecutionProvider", {})
+     ]
+     parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
+     input_name = parsings_session.get_inputs()[0].name
+     output_names = [output.name for output in parsings_session.get_outputs()]
+
+     mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
+     std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]
+
+     infer_parsing = lambda img: torch.tensor(
+         parsings_session.run(output_names, {
+             input_name: (((img[:, [2, 1, 0], ...] / 2 + 0.5).cpu().detach().numpy() - mean) / std).astype(np.float32)
+         })[0],
+         device='cuda',
+         dtype=torch.float32
+     )
+
+     source_pil = Image.open(args.source)
+     target_pil = Image.open(args.target)
+
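+     # Gradio UI: source and target image inputs, a Generate button, and the output viewer.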
+     with gr.Blocks() as demo:
+         with gr.Column():
+             # gr.HTML(title)
+
+             with gr.Row():
+                 with gr.Column():
+                     input_source = gr.Image(
+                         type="pil",
+                         label="Input Source"
+                     )
+                     input_target = gr.Image(
+                         type="pil",
+                         label="Input Target"
+                     )
+                     run_button = gr.Button("Generate")
+
+                     # with gr.Row():
+                     #     with gr.Column(scale=2):
+                     #         prompt_input = gr.Textbox(label="Prompt (Optional)")
+                     #     with gr.Column(scale=1):
+                     #         run_button = gr.Button("Generate")
+
+                     # with gr.Row():
+                     #     target_ratio = gr.Radio(
+                     #         label="Expected Ratio",
+                     #         choices=["9:16", "16:9", "1:1", "Custom"],
+                     #         value="9:16",
+                     #         scale=2
+                     #     )
+
+                     #     alignment_dropdown = gr.Dropdown(
+                     #         choices=["Middle", "Left", "Right", "Top", "Bottom"],
+                     #         value="Middle",
+                     #         label="Alignment"
+                     #     )
+
+                     # with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
+                     #     with gr.Column():
+                     #         with gr.Row():
+                     #             width_slider = gr.Slider(
+                     #                 label="Target Width",
+                     #                 minimum=720,
+                     #                 maximum=1536,
+                     #                 step=8,
+                     #                 value=720, # Set a default value
+                     #             )
+                     #             height_slider = gr.Slider(
+                     #                 label="Target Height",
+                     #                 minimum=720,
+                     #                 maximum=1536,
+                     #                 step=8,
+                     #                 value=1280, # Set a default value
+                     #             )
+
+                     #         num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
+                     #         with gr.Group():
+                     #             overlap_percentage = gr.Slider(
+                     #                 label="Mask overlap (%)",
+                     #                 minimum=1,
+                     #                 maximum=50,
+                     #                 value=10,
+                     #                 step=1
+                     #             )
+                     #             with gr.Row():
+                     #                 overlap_top = gr.Checkbox(label="Overlap Top", value=True)
+                     #                 overlap_right = gr.Checkbox(label="Overlap Right", value=True)
+                     #             with gr.Row():
+                     #                 overlap_left = gr.Checkbox(label="Overlap Left", value=True)
+                     #                 overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
+                     #         with gr.Row():
+                     #             resize_option = gr.Radio(
+                     #                 label="Resize input image",
+                     #                 choices=["Full", "50%", "33%", "25%", "Custom"],
+                     #                 value="Full"
+                     #             )
+                     #             custom_resize_percentage = gr.Slider(
+                     #                 label="Custom resize (%)",
+                     #                 minimum=1,
+                     #                 maximum=100,
+                     #                 step=1,
+                     #                 value=50,
+                     #                 visible=False
+                     #             )
+
+                     #     with gr.Column():
+                     #         preview_button = gr.Button("Preview alignment and mask")
+
+
+                     # gr.Examples(
+                     #     examples=[
+                     #         ["./examples/example_1.webp", 1280, 720, "Middle"],
+                     #         ["./examples/example_2.jpg", 1440, 810, "Left"],
+                     #         ["./examples/example_3.jpg", 1024, 1024, "Top"],
+                     #         ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
+                     #     ],
+                     #     inputs=[input_image, width_slider, height_slider, alignment_dropdown],
+                     # )
+
+
+
+                 with gr.Column():
+                     result = ImageSlider(
+                         interactive=False,
+                         label="Generated Image",
+                     )
+                     # use_as_input_button = gr.Button("Use as Input Image", visible=False)
+
+                     # history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
+                     # preview_image = gr.Image(label="Preview")
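+         # Clicking Generate runs the head swap on the two uploaded images.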
+         gr.on(
+             triggers=[run_button.click],
+             fn=infer_headswap,
+             inputs=[input_source, input_target],
+             outputs=[result]
+         )
+
+
+     demo.launch()