MonsterMMORPG committed on
Commit d4e3209 · 1 Parent(s): ead64bc

Upload webUI_rerender_v1.py

Files changed (1)
  1. webUI_rerender_v1.py +970 -0
webUI_rerender_v1.py ADDED
@@ -0,0 +1,970 @@
1
+ import os
2
+ import shutil
3
+ from enum import Enum
4
+
5
+ import cv2
6
+ import einops
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as T
12
+ from blendmodes.blend import BlendType, blendLayers
13
+ from PIL import Image
14
+ from pytorch_lightning import seed_everything
15
+ from safetensors.torch import load_file
16
+ from skimage import exposure
17
+
18
+ import src.import_util # noqa: F401
19
+ from deps.ControlNet.annotator.canny import CannyDetector
20
+ from deps.ControlNet.annotator.hed import HEDdetector
21
+ from deps.ControlNet.annotator.util import HWC3
22
+ from deps.ControlNet.cldm.model import create_model, load_state_dict
23
+ from deps.gmflow.gmflow.gmflow import GMFlow
24
+ from flow.flow_utils import get_warped_and_mask
25
+ from sd_model_cfg import model_dict
26
+ from src.config import RerenderConfig
27
+ from src.controller import AttentionControl
28
+ from src.ddim_v_hacked import DDIMVSampler
29
+ from src.freeu import freeu_forward
30
+ from src.img_util import find_flat_region, numpy2tensor
31
+ from src.video_util import (frame_to_video, get_fps, get_frame_count,
32
+ prepare_frames)
33
+
34
+ inversed_model_dict = dict()
35
+ for k, v in model_dict.items():
36
+ inversed_model_dict[v] = k
37
+
38
+ to_tensor = T.PILToTensor()
39
+ blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
40
+
41
+
42
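+ # Stages of the rerender pipeline: nothing done yet, first key frame generated,
+ # or all key frames generated.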
+ class ProcessingState(Enum):
43
+ NULL = 0
44
+ FIRST_IMG = 1
45
+ KEY_IMGS = 2
46
+
47
+
48
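+ # Shared state for the Gradio callbacks: the loaded SD/ControlNet model, edge
+ # detector, cross-frame attention controller, and the GMFlow optical-flow model.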
+ class GlobalState:
49
+
50
+ def __init__(self):
51
+ self.sd_model = None
52
+ self.ddim_v_sampler = None
53
+ self.detector_type = None
54
+ self.detector = None
55
+ self.controller = None
56
+ self.processing_state = ProcessingState.NULL
+ # Set in process1 when color is not preserved; process2 checks it against None.
+ self.color_corrections = None
57
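+ # Optical-flow estimator used to warp previous results onto the current frame.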
+ flow_model = GMFlow(
58
+ feature_channels=128,
59
+ num_scales=1,
60
+ upsample_factor=8,
61
+ num_head=1,
62
+ attention_type='swin',
63
+ ffn_dim_expansion=4,
64
+ num_transformer_layers=6,
65
+ ).to('cuda')
66
+
67
+ checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
68
+ map_location=lambda storage, loc: storage)
69
+ weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
70
+ flow_model.load_state_dict(weights, strict=False)
71
+ flow_model.eval()
72
+ self.flow_model = flow_model
73
+
74
+ def update_controller(self, inner_strength, mask_period, cross_period,
75
+ ada_period, warp_period, loose_cfattn):
76
+ self.controller = AttentionControl(inner_strength,
77
+ mask_period,
78
+ cross_period,
79
+ ada_period,
80
+ warp_period,
81
+ loose_cfatnn=loose_cfattn)
82
+
83
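+ # (Re)load the ControlNet model and the selected SD checkpoint; also try to load
+ # the fine-tuned VAE and patch the UNet forward with FreeU.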
+ def update_sd_model(self, sd_model, control_type, freeu_args):
84
+ if sd_model == self.sd_model:
85
+ return
86
+ self.sd_model = sd_model
87
+ model = create_model('./deps/ControlNet/models/cldm_v15.yaml').cpu()
88
+ if control_type == 'HED':
89
+ model.load_state_dict(
90
+ load_state_dict('./models/control_sd15_hed.pth',
91
+ location='cuda'))
92
+ elif control_type == 'canny':
93
+ model.load_state_dict(
94
+ load_state_dict('./models/control_sd15_canny.pth',
95
+ location='cuda'))
96
+ model = model.cuda()
97
+ sd_model_path = model_dict[sd_model]
98
+ if len(sd_model_path) > 0:
99
+ model_ext = os.path.splitext(sd_model_path)[1]
100
+ if model_ext == '.safetensors':
101
+ model.load_state_dict(load_file(sd_model_path), strict=False)
102
+ elif model_ext == '.ckpt' or model_ext == '.pth':
103
+ model.load_state_dict(torch.load(sd_model_path)['state_dict'],
104
+ strict=False)
105
+
106
+ try:
107
+ model.first_stage_model.load_state_dict(torch.load(
108
+ './models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'],
109
+ strict=False)
110
+ except Exception:
111
+ print('Warning: We suggest you download the fine-tuned VAE',
112
+ 'otherwise the generation quality will be degraded')
113
+
114
+ model.model.diffusion_model.forward = freeu_forward(
115
+ model.model.diffusion_model, *freeu_args)
116
+ self.ddim_v_sampler = DDIMVSampler(model)
117
+
118
+ def clear_sd_model(self):
119
+ self.sd_model = None
120
+ self.ddim_v_sampler = None
121
+ torch.cuda.empty_cache()
122
+
123
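+ # Build the edge detector matching the control type: HED network, or Canny with
+ # the given thresholds.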
+ def update_detector(self, control_type, canny_low=100, canny_high=200):
124
+ if self.detector_type == control_type:
125
+ return
126
+ if control_type == 'HED':
127
+ self.detector = HEDdetector()
128
+ elif control_type == 'canny':
129
+ canny_detector = CannyDetector()
130
+ low_threshold = canny_low
131
+ high_threshold = canny_high
132
+
133
+ def apply_canny(x):
134
+ return canny_detector(x, low_threshold, high_threshold)
135
+
136
+ self.detector = apply_canny
137
+
138
+
139
+ global_state = GlobalState()
140
+ global_video_path = None
141
+ video_frame_count = None
142
+
143
+
144
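+ # Collect the UI values into a RerenderConfig; a disabled constraint is encoded
+ # as an empty period (start=1, end=0) so it never takes effect.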
+ def create_cfg(input_path, prompt, image_resolution, control_strength,
145
+ color_preserve, left_crop, right_crop, top_crop, bottom_crop,
146
+ control_type, low_threshold, high_threshold, ddim_steps, scale,
147
+ seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
148
+ x0_strength, use_constraints, cross_start, cross_end,
149
+ style_update_freq, warp_start, warp_end, mask_start, mask_end,
150
+ ada_start, ada_end, mask_strength, inner_strength,
151
+ smooth_boundary, loose_cfattn, b1, b2, s1, s2):
152
+ use_warp = 'shape-aware fusion' in use_constraints
153
+ use_mask = 'pixel-aware fusion' in use_constraints
154
+ use_ada = 'color-aware AdaIN' in use_constraints
155
+
156
+ if not use_warp:
157
+ warp_start = 1
158
+ warp_end = 0
159
+
160
+ if not use_mask:
161
+ mask_start = 1
162
+ mask_end = 0
163
+
164
+ if not use_ada:
165
+ ada_start = 1
166
+ ada_end = 0
167
+
168
+ input_name = os.path.split(input_path)[-1].split('.')[0]
169
+ frame_count = 2 + keyframe_count * interval
170
+ cfg = RerenderConfig()
171
+ cfg.create_from_parameters(
172
+ input_path,
173
+ os.path.join('result', input_name, 'blend.mp4'),
174
+ prompt,
175
+ a_prompt=a_prompt,
176
+ n_prompt=n_prompt,
177
+ frame_count=frame_count,
178
+ interval=interval,
179
+ crop=[left_crop, right_crop, top_crop, bottom_crop],
180
+ sd_model=sd_model,
181
+ ddim_steps=ddim_steps,
182
+ scale=scale,
183
+ control_type=control_type,
184
+ control_strength=control_strength,
185
+ canny_low=low_threshold,
186
+ canny_high=high_threshold,
187
+ seed=seed,
188
+ image_resolution=image_resolution,
189
+ x0_strength=x0_strength,
190
+ style_update_freq=style_update_freq,
191
+ cross_period=(cross_start, cross_end),
192
+ warp_period=(warp_start, warp_end),
193
+ mask_period=(mask_start, mask_end),
194
+ ada_period=(ada_start, ada_end),
195
+ mask_strength=mask_strength,
196
+ inner_strength=inner_strength,
197
+ smooth_boundary=smooth_boundary,
198
+ color_preserve=color_preserve,
199
+ loose_cfattn=loose_cfattn,
200
+ freeu_args=[b1, b2, s1, s2])
201
+ return cfg
202
+
203
+
204
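+ # Load a saved config file and convert it back into the ordered list of UI
+ # values used by the Examples panel.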
+ def cfg_to_input(filename):
205
+
206
+ cfg = RerenderConfig()
207
+ cfg.create_from_path(filename)
208
+ keyframe_count = (cfg.frame_count - 2) // cfg.interval
209
+ use_constraints = [
210
+ 'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
211
+ ]
212
+
213
+ sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')
214
+
215
+ args = [
216
+ cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
217
+ cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
218
+ cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
219
+ cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
220
+ cfg.x0_strength, use_constraints, *cfg.cross_period,
221
+ cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
222
+ *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
223
+ cfg.smooth_boundary, cfg.loose_cfattn, *cfg.freeu_args
224
+ ]
225
+ return args
226
+
227
+
228
+ def setup_color_correction(image):
229
+ correction_target = cv2.cvtColor(np.asarray(image.copy()),
230
+ cv2.COLOR_RGB2LAB)
231
+ return correction_target
232
+
233
+
234
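+ # Match the image histogram to the reference in LAB space, then restore the
+ # original luminosity via a luminosity blend.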
+ def apply_color_correction(correction, original_image):
235
+ image = Image.fromarray(
236
+ cv2.cvtColor(
237
+ exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
238
+ cv2.COLOR_RGB2LAB),
239
+ correction,
240
+ channel_axis=2),
241
+ cv2.COLOR_LAB2RGB).astype('uint8'))
242
+
243
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
244
+
245
+ return image
246
+
247
+
248
+ @torch.no_grad()
249
+ def process(*args):
250
+ args_wo_process3 = args[:-2]
251
+ first_frame = process1(*args_wo_process3)
252
+
253
+ keypath = process2(*args_wo_process3)
254
+
255
+ fullpath = process3(*args)
256
+
257
+ return first_frame, keypath, fullpath
258
+
259
+
260
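+ # Stage 1: translate only the first frame with ControlNet-guided DDIM sampling.
+ # If color is not preserved, an unconstrained draft is generated first and used
+ # to color-correct the input before the final first frame is rendered.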
+ @torch.no_grad()
261
+ def process1(*args):
262
+
263
+ global global_video_path
264
+ cfg = create_cfg(global_video_path, *args)
265
+ global global_state
266
+ global_state.update_sd_model(cfg.sd_model, cfg.control_type,
267
+ cfg.freeu_args)
268
+ global_state.update_controller(cfg.inner_strength, cfg.mask_period,
269
+ cfg.cross_period, cfg.ada_period,
270
+ cfg.warp_period, cfg.loose_cfattn)
271
+ global_state.update_detector(cfg.control_type, cfg.canny_low,
272
+ cfg.canny_high)
273
+ global_state.processing_state = ProcessingState.FIRST_IMG
274
+
275
+ prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
276
+ cfg.crop)
277
+
278
+ ddim_v_sampler = global_state.ddim_v_sampler
279
+ model = ddim_v_sampler.model
280
+ detector = global_state.detector
281
+ controller = global_state.controller
282
+ model.control_scales = [cfg.control_strength] * 13
283
+
284
+ num_samples = 1
285
+ eta = 0.0
286
+ imgs = sorted(os.listdir(cfg.input_dir))
287
+ imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
288
+
289
+ with torch.no_grad():
290
+ frame = cv2.imread(imgs[0])
291
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
292
+ img = HWC3(frame)
293
+ H, W, C = img.shape
294
+
295
+ img_ = numpy2tensor(img)
296
+
297
+ def generate_first_img(img_, strength):
298
+ encoder_posterior = model.encode_first_stage(img_.cuda())
299
+ x0 = model.get_first_stage_encoding(encoder_posterior).detach()
300
+
301
+ detected_map = detector(img)
302
+ detected_map = HWC3(detected_map)
303
+
304
+ control = torch.from_numpy(
305
+ detected_map.copy()).float().cuda() / 255.0
306
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
307
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
308
+ cond = {
309
+ 'c_concat': [control],
310
+ 'c_crossattn': [
311
+ model.get_learned_conditioning(
312
+ [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
313
+ ]
314
+ }
315
+ un_cond = {
316
+ 'c_concat': [control],
317
+ 'c_crossattn':
318
+ [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
319
+ }
320
+ shape = (4, H // 8, W // 8)
321
+
322
+ controller.set_task('initfirst')
323
+ seed_everything(cfg.seed)
324
+
325
+ samples, _ = ddim_v_sampler.sample(
326
+ cfg.ddim_steps,
327
+ num_samples,
328
+ shape,
329
+ cond,
330
+ verbose=False,
331
+ eta=eta,
332
+ unconditional_guidance_scale=cfg.scale,
333
+ unconditional_conditioning=un_cond,
334
+ controller=controller,
335
+ x0=x0,
336
+ strength=strength)
337
+ x_samples = model.decode_first_stage(samples)
338
+ x_samples_np = (
339
+ einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
340
+ 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
341
+ return x_samples, x_samples_np
342
+
343
+ # When color is not preserved, first generate an unconstrained frame and use its
344
+ # colors to color-correct the input before redrawing the first frame.
345
+ if not cfg.color_preserve:
346
+ first_strength = -1
347
+ else:
348
+ first_strength = 1 - cfg.x0_strength
349
+
350
+ x_samples, x_samples_np = generate_first_img(img_, first_strength)
351
+
352
+ if not cfg.color_preserve:
353
+ color_corrections = setup_color_correction(
354
+ Image.fromarray(x_samples_np[0]))
355
+ global_state.color_corrections = color_corrections
356
+ img_ = apply_color_correction(color_corrections,
357
+ Image.fromarray(img))
358
+ img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
359
+ x_samples, x_samples_np = generate_first_img(
360
+ img_, 1 - cfg.x0_strength)
361
+
362
+ global_state.first_result = x_samples
363
+ global_state.first_img = img
364
+
365
+ Image.fromarray(x_samples_np[0]).save(
366
+ os.path.join(cfg.first_dir, 'first.jpg'))
367
+
368
+ return x_samples_np[0]
369
+
370
+
371
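+ # Stage 2: translate the key frames one by one. Each frame is constrained by the
+ # first and previous results warped with optical flow (shape/pixel/color-aware
+ # fusion), then the key frames are assembled into key.mp4.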
+ @torch.no_grad()
372
+ def process2(*args):
373
+ global global_state
374
+ global global_video_path
375
+
376
+ if global_state.processing_state != ProcessingState.FIRST_IMG:
377
+ raise gr.Error('Please generate the first key image before generating'
378
+ ' all key images')
379
+
380
+ cfg = create_cfg(global_video_path, *args)
381
+ global_state.update_sd_model(cfg.sd_model, cfg.control_type,
382
+ cfg.freeu_args)
383
+ global_state.update_detector(cfg.control_type, cfg.canny_low,
384
+ cfg.canny_high)
385
+ global_state.processing_state = ProcessingState.KEY_IMGS
386
+
387
+ # reset key dir
388
+ shutil.rmtree(cfg.key_dir)
389
+ os.makedirs(cfg.key_dir, exist_ok=True)
390
+
391
+ ddim_v_sampler = global_state.ddim_v_sampler
392
+ model = ddim_v_sampler.model
393
+ detector = global_state.detector
394
+ controller = global_state.controller
395
+ flow_model = global_state.flow_model
396
+ model.control_scales = [cfg.control_strength] * 13
397
+
398
+ num_samples = 1
399
+ eta = 0.0
400
+ firstx0 = True
401
+ pixelfusion = cfg.use_mask
402
+ imgs = sorted(os.listdir(cfg.input_dir))
403
+ imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
404
+
405
+ first_result = global_state.first_result
406
+ first_img = global_state.first_img
407
+ pre_result = first_result
408
+ pre_img = first_img
409
+
410
+ for i in range(0, min(len(imgs), cfg.frame_count) - 1, cfg.interval):
411
+ cid = i + 1
412
+ print(cid)
413
+ if cid <= (len(imgs) - 1):
414
+ frame = cv2.imread(imgs[cid])
415
+ else:
416
+ frame = cv2.imread(imgs[len(imgs) - 1])
417
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
418
+ img = HWC3(frame)
419
+ H, W, C = img.shape
420
+
421
+ if cfg.color_preserve or global_state.color_corrections is None:
422
+ img_ = numpy2tensor(img)
423
+ else:
424
+ img_ = apply_color_correction(global_state.color_corrections,
425
+ Image.fromarray(img))
426
+ img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
427
+ encoder_posterior = model.encode_first_stage(img_.cuda())
428
+ x0 = model.get_first_stage_encoding(encoder_posterior).detach()
429
+
430
+ detected_map = detector(img)
431
+ detected_map = HWC3(detected_map)
432
+
433
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
434
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
435
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
436
+ cond = {
437
+ 'c_concat': [control],
438
+ 'c_crossattn': [
439
+ model.get_learned_conditioning(
440
+ [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
441
+ ]
442
+ }
443
+ un_cond = {
444
+ 'c_concat': [control],
445
+ 'c_crossattn':
446
+ [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
447
+ }
448
+ shape = (4, H // 8, W // 8)
449
+
450
+ cond['c_concat'] = [control]
451
+ un_cond['c_concat'] = [control]
452
+
453
+ image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
454
+ image2 = torch.from_numpy(img).permute(2, 0, 1).float()
455
+ warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
456
+ flow_model, image1, image2, pre_result, False)
457
+ blend_mask_pre = blur(
458
+ F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
459
+ blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
460
+
461
+ image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
462
+ warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
463
+ flow_model, image1, image2, first_result, False)
464
+ blend_mask_0 = blur(
465
+ F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
466
+ blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
467
+
468
+ if firstx0:
469
+ mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
470
+ controller.set_warp(
471
+ F.interpolate(bwd_flow_0 / 8.0,
472
+ scale_factor=1. / 8,
473
+ mode='bilinear'), mask)
474
+ else:
475
+ mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
476
+ controller.set_warp(
477
+ F.interpolate(bwd_flow_pre / 8.0,
478
+ scale_factor=1. / 8,
479
+ mode='bilinear'), mask)
480
+
481
+ controller.set_task('keepx0, keepstyle')
482
+ seed_everything(cfg.seed)
483
+ samples, intermediates = ddim_v_sampler.sample(
484
+ cfg.ddim_steps,
485
+ num_samples,
486
+ shape,
487
+ cond,
488
+ verbose=False,
489
+ eta=eta,
490
+ unconditional_guidance_scale=cfg.scale,
491
+ unconditional_conditioning=un_cond,
492
+ controller=controller,
493
+ x0=x0,
494
+ strength=1 - cfg.x0_strength)
495
+ direct_result = model.decode_first_stage(samples)
496
+
497
+ if not pixelfusion:
498
+ pre_result = direct_result
499
+ pre_img = img
500
+ viz = (
501
+ einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
502
+ 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
503
+
504
+ else:
505
+
506
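+ # Pixel-aware fusion: composite the warped previous and first results with the
+ # directly generated frame, weighted by the occlusion-based blend masks.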
+ blend_results = (1 - blend_mask_pre
507
+ ) * warped_pre + blend_mask_pre * direct_result
508
+ blend_results = (
509
+ 1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results
510
+
511
+ bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
512
+ blend_mask = blur(
513
+ F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
514
+ blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)
515
+
516
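+ # Encode the fused image to latents and compensate for the VAE round-trip error
+ # (xtrg - xtrg_rec); mask_x marks regions where that correction is unreliable.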
+ encoder_posterior = model.encode_first_stage(blend_results)
517
+ xtrg = model.get_first_stage_encoding(
518
+ encoder_posterior).detach() # * mask
519
+ blend_results_rec = model.decode_first_stage(xtrg)
520
+ encoder_posterior = model.encode_first_stage(blend_results_rec)
521
+ xtrg_rec = model.get_first_stage_encoding(
522
+ encoder_posterior).detach()
523
+ xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec)) # * mask
524
+ blend_results_rec_new = model.decode_first_stage(xtrg_)
525
+ tmp = (abs(blend_results_rec_new - blend_results).mean(
526
+ dim=1, keepdims=True) > 0.25).float()
527
+ mask_x = F.max_pool2d((F.interpolate(
528
+ tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(),
529
+ kernel_size=3,
530
+ stride=1,
531
+ padding=1)
532
+
533
+ mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8)
534
+ ) # * (1-mask_x)
535
+
536
+ if cfg.smooth_boundary:
537
+ noise_rescale = find_flat_region(mask)
538
+ else:
539
+ noise_rescale = torch.ones_like(mask)
540
+ masks = []
541
+ for i in range(cfg.ddim_steps):
542
+ if i <= cfg.ddim_steps * cfg.mask_period[
543
+ 0] or i >= cfg.ddim_steps * cfg.mask_period[1]:
544
+ masks += [None]
545
+ else:
546
+ masks += [mask * cfg.mask_strength]
547
+
548
+ # mask 3
549
+ # xtrg = ((1-mask_x) *
550
+ # (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask
551
+ # mask 2
552
+ # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask
553
+ xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask # mask 1
554
+
555
+ tasks = 'keepstyle, keepx0'
556
+ if not firstx0:
557
+ tasks += ', updatex0'
558
+ if i % cfg.style_update_freq == 0:
559
+ tasks += ', updatestyle'
560
+ controller.set_task(tasks, 1.0)
561
+
562
+ seed_everything(cfg.seed)
563
+ samples, _ = ddim_v_sampler.sample(
564
+ cfg.ddim_steps,
565
+ num_samples,
566
+ shape,
567
+ cond,
568
+ verbose=False,
569
+ eta=eta,
570
+ unconditional_guidance_scale=cfg.scale,
571
+ unconditional_conditioning=un_cond,
572
+ controller=controller,
573
+ x0=x0,
574
+ strength=1 - cfg.x0_strength,
575
+ xtrg=xtrg,
576
+ mask=masks,
577
+ noise_rescale=noise_rescale)
578
+ x_samples = model.decode_first_stage(samples)
579
+ pre_result = x_samples
580
+ pre_img = img
581
+
582
+ viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
583
+ 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
584
+
585
+ Image.fromarray(viz[0]).save(
586
+ os.path.join(cfg.key_dir, f'{cid:04d}.png'))
587
+
588
+ key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
589
+ fps = get_fps(cfg.input_path)
590
+ fps //= cfg.interval
591
+ frame_to_video(key_video_path, cfg.key_dir, fps, False)
592
+
593
+ return key_video_path
594
+
595
+
596
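+ # Stage 3: release the SD model from GPU memory and call video_blend.py to
+ # propagate the key frames to every frame of the output video.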
+ @torch.no_grad()
597
+ def process3(*args):
598
+ max_process = args[-2]
599
+ use_poisson = args[-1]
600
+ args = args[:-2]
601
+ global global_video_path
602
+ global global_state
603
+ if global_state.processing_state != ProcessingState.KEY_IMGS:
604
+ raise gr.Error('Please generate key images before propagation')
605
+
606
+ global_state.clear_sd_model()
607
+
608
+ cfg = create_cfg(global_video_path, *args)
609
+
610
+ # reset blend dir
611
+ blend_dir = os.path.join(cfg.work_dir, 'blend')
612
+ if os.path.exists(blend_dir):
613
+ shutil.rmtree(blend_dir)
614
+ os.makedirs(blend_dir, exist_ok=True)
615
+
616
+ video_base_dir = cfg.work_dir
617
+ o_video = cfg.output_path
618
+ fps = get_fps(cfg.input_path)
619
+
620
+ end_frame = cfg.frame_count - 1
621
+ interval = cfg.interval
622
+ key_dir = os.path.split(cfg.key_dir)[-1]
623
+ o_video_cmd = f'--output {o_video}'
624
+ ps = '-ps' if use_poisson else ''
625
+ cmd = (f'python video_blend.py {video_base_dir} --beg 1 --end {end_frame} '
626
+ f'--itv {interval} --key {key_dir} {o_video_cmd} --fps {fps} '
627
+ f'--n_proc {max_process} {ps}')
628
+ print(cmd)
629
+ os.system(cmd)
630
+
631
+ return o_video
632
+
633
+
634
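+ # Gradio UI: the left column holds the inputs and advanced options, the right
+ # column shows the first frame, key-frame video, and full-video outputs.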
+ block = gr.Blocks().queue()
635
+ with block:
636
+ with gr.Row():
637
+ gr.Markdown('## Rerender A Video')
638
+ with gr.Row():
639
+ with gr.Column():
640
+ input_path = gr.Video(label='Input Video',
641
+ source='upload',
642
+ format='mp4',
643
+ visible=True)
644
+ prompt = gr.Textbox(label='Prompt')
645
+ seed = gr.Slider(label='Seed',
646
+ minimum=0,
647
+ maximum=2147483647,
648
+ step=1,
649
+ value=0,
650
+ randomize=True)
651
+ run_button = gr.Button(value='Run All')
652
+ with gr.Row():
653
+ run_button1 = gr.Button(value='Run 1st Key Frame')
654
+ run_button2 = gr.Button(value='Run Key Frames')
655
+ run_button3 = gr.Button(value='Run Propagation')
656
+ with gr.Accordion('Advanced options for the 1st frame translation',
657
+ open=False):
658
+ image_resolution = gr.Slider(label='Frame resolution',
659
+ minimum=256,
660
+ maximum=768,
661
+ value=512,
662
+ step=64)
663
+ control_strength = gr.Slider(label='ControlNet strength',
664
+ minimum=0.0,
665
+ maximum=2.0,
666
+ value=1.0,
667
+ step=0.01)
668
+ x0_strength = gr.Slider(
669
+ label='Denoising strength',
670
+ minimum=0.00,
671
+ maximum=1.05,
672
+ value=0.75,
673
+ step=0.05,
674
+ info=('0: fully recover the input. '
675
+ '1.05: fully rerender the input.'))
676
+ color_preserve = gr.Checkbox(
677
+ label='Preserve color',
678
+ value=True,
679
+ info='Keep the color of the input video')
680
+ with gr.Row():
681
+ left_crop = gr.Slider(label='Left crop length',
682
+ minimum=0,
683
+ maximum=512,
684
+ value=0,
685
+ step=1)
686
+ right_crop = gr.Slider(label='Right crop length',
687
+ minimum=0,
688
+ maximum=512,
689
+ value=0,
690
+ step=1)
691
+ with gr.Row():
692
+ top_crop = gr.Slider(label='Top crop length',
693
+ minimum=0,
694
+ maximum=512,
695
+ value=0,
696
+ step=1)
697
+ bottom_crop = gr.Slider(label='Bottom crop length',
698
+ minimum=0,
699
+ maximum=512,
700
+ value=0,
701
+ step=1)
702
+ with gr.Row():
703
+ control_type = gr.Dropdown(['HED', 'canny'],
704
+ label='Control type',
705
+ value='HED')
706
+ low_threshold = gr.Slider(label='Canny low threshold',
707
+ minimum=1,
708
+ maximum=255,
709
+ value=100,
710
+ step=1)
711
+ high_threshold = gr.Slider(label='Canny high threshold',
712
+ minimum=1,
713
+ maximum=255,
714
+ value=200,
715
+ step=1)
716
+ ddim_steps = gr.Slider(label='Steps',
717
+ minimum=20,
718
+ maximum=100,
719
+ value=20,
720
+ step=20)
721
+ scale = gr.Slider(label='CFG scale',
722
+ minimum=0.1,
723
+ maximum=30.0,
724
+ value=7.5,
725
+ step=0.1)
726
+ sd_model_list = list(model_dict.keys())
727
+ sd_model = gr.Dropdown(sd_model_list,
728
+ label='Base model',
729
+ value='Stable Diffusion 1.5')
730
+ a_prompt = gr.Textbox(label='Added prompt',
731
+ value='best quality, extremely detailed')
732
+ n_prompt = gr.Textbox(
733
+ label='Negative prompt',
734
+ value=('longbody, lowres, bad anatomy, bad hands, '
735
+ 'missing fingers, extra digit, fewer digits, '
736
+ 'cropped, worst quality, low quality'))
737
+ with gr.Row():
738
+ b1 = gr.Slider(label='FreeU first-stage backbone factor',
739
+ minimum=1,
740
+ maximum=1.6,
741
+ value=1,
742
+ step=0.01,
743
+ info='FreeU to enhance texture and color')
744
+ b2 = gr.Slider(label='FreeU second-stage backbone factor',
745
+ minimum=1,
746
+ maximum=1.6,
747
+ value=1,
748
+ step=0.01)
749
+ with gr.Row():
750
+ s1 = gr.Slider(label='FreeU first-stage skip factor',
751
+ minimum=0,
752
+ maximum=1,
753
+ value=1,
754
+ step=0.01)
755
+ s2 = gr.Slider(label='FreeU second-stage skip factor',
756
+ minimum=0,
757
+ maximum=1,
758
+ value=1,
759
+ step=0.01)
760
+ with gr.Accordion('Advanced options for the key frame translation',
761
+ open=False):
762
+ interval = gr.Slider(
763
+ label='Key frame frequency (K)',
764
+ minimum=1,
765
+ maximum=1,
766
+ value=1,
767
+ step=1,
768
+ info='Uniformly sample the key frames every K frames')
769
+ keyframe_count = gr.Slider(label='Number of key frames',
770
+ minimum=1,
771
+ maximum=1,
772
+ value=1,
773
+ step=1)
774
+
775
+ use_constraints = gr.CheckboxGroup(
776
+ [
777
+ 'shape-aware fusion', 'pixel-aware fusion',
778
+ 'color-aware AdaIN'
779
+ ],
780
+ label='Select the cross-frame constraints to be used',
781
+ value=[
782
+ 'shape-aware fusion', 'pixel-aware fusion',
783
+ 'color-aware AdaIN'
784
+ ])
785
+ with gr.Row():
786
+ cross_start = gr.Slider(
787
+ label='Cross-frame attention start',
788
+ minimum=0,
789
+ maximum=1,
790
+ value=0,
791
+ step=0.05)
792
+ cross_end = gr.Slider(label='Cross-frame attention end',
793
+ minimum=0,
794
+ maximum=1,
795
+ value=1,
796
+ step=0.05)
797
+ style_update_freq = gr.Slider(
798
+ label='Cross-frame attention update frequency',
799
+ minimum=1,
800
+ maximum=100,
801
+ value=1,
802
+ step=1,
803
+ info=('Update the key and value for '
804
+ 'cross-frame attention every N key frames'))
805
+ loose_cfattn = gr.Checkbox(
806
+ label='Loose Cross-frame attention',
807
+ value=True,
808
+ info='Select to make the output better match the input video')
809
+ with gr.Row():
810
+ warp_start = gr.Slider(label='Shape-aware fusion start',
811
+ minimum=0,
812
+ maximum=1,
813
+ value=0,
814
+ step=0.05)
815
+ warp_end = gr.Slider(label='Shape-aware fusion end',
816
+ minimum=0,
817
+ maximum=1,
818
+ value=0.1,
819
+ step=0.05)
820
+ with gr.Row():
821
+ mask_start = gr.Slider(label='Pixel-aware fusion start',
822
+ minimum=0,
823
+ maximum=1,
824
+ value=0.5,
825
+ step=0.05)
826
+ mask_end = gr.Slider(label='Pixel-aware fusion end',
827
+ minimum=0,
828
+ maximum=1,
829
+ value=0.8,
830
+ step=0.05)
831
+ with gr.Row():
832
+ ada_start = gr.Slider(label='Color-aware AdaIN start',
833
+ minimum=0,
834
+ maximum=1,
835
+ value=0.8,
836
+ step=0.05)
837
+ ada_end = gr.Slider(label='Color-aware AdaIN end',
838
+ minimum=0,
839
+ maximum=1,
840
+ value=1,
841
+ step=0.05)
842
+ mask_strength = gr.Slider(label='Pixel-aware fusion strength',
843
+ minimum=0,
844
+ maximum=1,
845
+ value=0.5,
846
+ step=0.01)
847
+ inner_strength = gr.Slider(
848
+ label='Pixel-aware fusion detail level',
849
+ minimum=0.5,
850
+ maximum=1,
851
+ value=0.9,
852
+ step=0.01,
853
+ info='Use a low value to prevent artifacts')
854
+ smooth_boundary = gr.Checkbox(
855
+ label='Smooth fusion boundary',
856
+ value=True,
857
+ info='Select to prevent artifacts at boundary')
858
+ with gr.Accordion(
859
+ 'Advanced options for the full video translation',
860
+ open=False):
861
+ use_poisson = gr.Checkbox(
862
+ label='Gradient blending',
863
+ value=True,
864
+ info=('Blend the output video in the gradient domain to reduce'
865
+ ' ghosting artifacts (but may increase flickering)'))
866
+ max_process = gr.Slider(label='Number of parallel processes',
867
+ minimum=1,
868
+ maximum=16,
869
+ value=4,
870
+ step=1)
871
+
872
+ with gr.Accordion('Example configs', open=True):
873
+ config_dir = 'config'
874
+ config_list = [
875
+ 'real2sculpture.json', 'van_gogh_man.json', 'woman.json'
876
+ ]
877
+ args_list = []
878
+ for config in config_list:
879
+ try:
880
+ config_path = os.path.join(config_dir, config)
881
+ args = cfg_to_input(config_path)
882
+ args_list.append(args)
883
+ except FileNotFoundError:
884
+ # Skip this example if its input video file is missing.
885
+ pass
886
+
887
+ ips = [
888
+ prompt, image_resolution, control_strength, color_preserve,
889
+ left_crop, right_crop, top_crop, bottom_crop, control_type,
890
+ low_threshold, high_threshold, ddim_steps, scale, seed,
891
+ sd_model, a_prompt, n_prompt, interval, keyframe_count,
892
+ x0_strength, use_constraints, cross_start, cross_end,
893
+ style_update_freq, warp_start, warp_end, mask_start,
894
+ mask_end, ada_start, ada_end, mask_strength,
895
+ inner_strength, smooth_boundary, loose_cfattn, b1, b2, s1,
896
+ s2
897
+ ]
898
+
899
+ gr.Examples(
900
+ examples=args_list,
901
+ inputs=[input_path, *ips],
902
+ )
903
+
904
+ with gr.Column():
905
+ result_image = gr.Image(label='Output first frame',
906
+ type='numpy',
907
+ interactive=False)
908
+ result_keyframe = gr.Video(label='Output key frame video',
909
+ format='mp4',
910
+ interactive=False)
911
+ result_video = gr.Video(label='Output full video',
912
+ format='mp4',
913
+ interactive=False)
914
+
915
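+ # When a video is uploaded, derive the default key frame interval and the
+ # maximum key frame count from its frame count and update the sliders.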
+ def input_uploaded(path):
916
+ frame_count = get_frame_count(path)
917
+ if frame_count <= 2:
918
+ raise gr.Error('The input video is too short!'
919
+ ' Please input another video.')
920
+
921
+ default_interval = min(10, frame_count - 2)
922
+ max_keyframe = (frame_count - 2) // default_interval
923
+
924
+ global video_frame_count
925
+ video_frame_count = frame_count
926
+ global global_video_path
927
+ global_video_path = path
928
+
929
+ return gr.Slider.update(value=default_interval,
930
+ maximum=max_keyframe), gr.Slider.update(
931
+ value=max_keyframe, maximum=max_keyframe)
932
+
933
+ def input_changed(path):
934
+ frame_count = get_frame_count(path)
935
+ if frame_count <= 2:
936
+ return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)
937
+
938
+ default_interval = min(10, frame_count - 2)
939
+ max_keyframe = (frame_count - 2) // default_interval
940
+
941
+ global video_frame_count
942
+ video_frame_count = frame_count
943
+ global global_video_path
944
+ global_video_path = path
945
+
946
+ return gr.Slider.update(maximum=max_keyframe), \
947
+ gr.Slider.update(maximum=max_keyframe)
948
+
949
+ def interval_changed(interval):
950
+ global video_frame_count
951
+ if video_frame_count is None:
952
+ return gr.Slider.update()
953
+
954
+ max_keyframe = (video_frame_count - 2) // interval
955
+
956
+ return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)
957
+
958
+ input_path.change(input_changed, input_path, [interval, keyframe_count])
959
+ input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
960
+ interval.change(interval_changed, interval, keyframe_count)
961
+
962
+ ips_process3 = [*ips, max_process, use_poisson]
963
+ run_button.click(fn=process,
964
+ inputs=ips_process3,
965
+ outputs=[result_image, result_keyframe, result_video])
966
+ run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
967
+ run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])
968
+ run_button3.click(fn=process3, inputs=ips_process3, outputs=[result_video])
969
+
970
+ block.queue(concurrency_count=10).launch(share=True)