|
import argparse |
|
import os |
|
import platform |
|
import struct |
|
import subprocess |
|
import time |
|
from typing import List |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch.multiprocessing as mp |
|
from numba import njit |
|
|
|
import sys |
|
sys.path.append("./src/ebsynth/") |
|
import blender.histogram_blend as histogram_blend |
|
from blender.guide import (BaseGuide, ColorGuide, EdgeGuide, PositionalGuide, |
|
TemporalGuide) |
|
from blender.poisson_fusion import poisson_fusion |
|
from blender.video_sequence import VideoSequence |
|
from flow.flow_utils import flow_calc |
|
from src.video_util import frame_to_video |
|
|
|
# When True, each ebsynth command is printed and its output is not captured.
OPEN_EBSYNTH_LOG = False

# Upper bound on the number of parallel ebsynth worker processes
# (overwritten by ``main`` from ``--n_proc``).
MAX_PROCESS = 8

os_str = platform.system()

# Path of the ebsynth executable for each supported platform, relative to
# the current working directory.
if os_str == 'Windows':
    ebsynth_bin = '.\\src\\ebsynth\\deps\\ebsynth\\bin\\ebsynth.exe'
elif os_str == 'Linux':
    ebsynth_bin = './src/ebsynth/deps/ebsynth/bin/ebsynth'
elif os_str == 'Darwin':
    ebsynth_bin = './src/ebsynth/deps/ebsynth/bin/ebsynth.app'
else:
    # Abort with a non-zero status: exiting with 0 would tell callers and
    # scripts that the run succeeded although nothing can run without the
    # binary. ``sys.exit`` (unlike the ``site`` builtin ``exit``) is always
    # available; given a string it prints it to stderr and exits with 1.
    sys.exit('Cannot recognize OS. Run Ebsynth failed.')
|
|
|
|
|
@njit
def g_error_mask_loop(H, W, dist1, dist2, output, weight1, weight2):
    """Fill ``output[:H, :W]`` in place with a per-pixel source choice.

    A cell is set to 0 when ``weight1 * dist1`` is strictly smaller than
    ``weight2 * dist2`` at that position, and to 1 otherwise.  A zero weight
    overrides the comparison for every cell: ``weight1 == 0`` forces 0;
    otherwise ``weight2 == 0`` forces 1.
    """
    # The zero-weight overrides do not depend on the pixel, so settle them
    # once instead of re-testing them inside the loop.
    if weight1 == 0 or weight2 == 0:
        constant = 0 if weight1 == 0 else 1
        for row in range(H):
            for col in range(W):
                output[row, col] = constant
        return

    for row in range(H):
        for col in range(W):
            first_is_better = weight1 * dist1[row, col] < weight2 * dist2[row, col]
            output[row, col] = 0 if first_is_better else 1
|
|
|
|
|
def g_error_mask(dist1, dist2, weight1=1, weight2=1):
    """Return an ``np.byte`` mask with the shape of ``dist1``.

    Each entry is 0 where ``weight1 * dist1`` is strictly smaller than
    ``weight2 * dist2`` and 1 otherwise. A zero ``weight1`` makes the whole
    mask 0; otherwise a zero ``weight2`` makes it 1. ``dist1`` must be 2-D,
    and ``dist2`` is indexed over the same height/width.
    """
    rows, cols = dist1.shape
    mask = np.empty_like(dist1, dtype=np.byte)
    g_error_mask_loop(rows, cols, dist1, dist2, mask, weight1, weight2)
    return mask
|
|
|
|
|
def create_sequence(base_dir, key_ind, key_dir):
    """Construct the ``VideoSequence`` used by the rest of this script.

    Besides the arguments, it passes ``'video'``, ``'tmp'`` and two
    ``'%04d.png'`` patterns positionally (presumably the input sub-folder,
    temp sub-folder and file-name formats -- confirm against VideoSequence).
    """
    return VideoSequence(base_dir, key_ind, 'video', key_dir, 'tmp',
                         '%04d.png', '%04d.png')
|
|
|
|
|
def _run_ebsynth_cmd(cmd):
    """Execute one ebsynth command line and report a non-zero exit status.

    Output is captured (and therefore hidden) unless ``OPEN_EBSYNTH_LOG`` is
    set. A failure is reported instead of raised, so the remaining frames
    are still attempted.
    """
    if OPEN_EBSYNTH_LOG:
        print(cmd)
    # NOTE(review): the command goes through the shell with unquoted paths,
    # so paths containing spaces or shell metacharacters will break it.
    result = subprocess.run(cmd,
                            shell=True,
                            capture_output=not OPEN_EBSYNTH_LOG)
    if result.returncode != 0:
        print(f'ebsynth failed with exit code {result.returncode}: {cmd}')
        # stdout/stderr are bytes when captured and None otherwise.
        for stream in (result.stdout, result.stderr):
            if stream:
                print(stream.decode(errors='replace'))
    return result


def process_one_sequence(i, video_sequence: VideoSequence):
    """Run ebsynth forward and backward over sequence ``i``.

    For each direction this:
      1. computes optical flow between consecutive input frames via
         ``flow_calc.get_flow`` (written to the flow sequence paths),
      2. copies the key image into slot 0 of the output sequence,
      3. invokes the ebsynth binary once per remaining frame with colour,
         edge, temporal and positional guides.

    Args:
        i: Index of the sequence inside ``video_sequence``.
        video_sequence: Provides every input/output/flow/guide file path.
    """
    interval = video_sequence.interval(i)
    # Guide weights, in the same order as the guide list built below.
    weights = [6, 0.5, 0.5, 2]

    for is_forward in [True, False]:
        input_seq = video_sequence.get_input_sequence(i, is_forward)
        output_seq = video_sequence.get_output_sequence(i, is_forward)
        flow_seq = video_sequence.get_flow_sequence(i, is_forward)
        # Forward passes start from key ``i``, backward passes from key ``i + 1``.
        key_img_id = i if is_forward else i + 1
        key_img = video_sequence.get_key_img(key_img_id)

        for j in range(interval - 1):
            i1 = cv2.imread(input_seq[j])
            i2 = cv2.imread(input_seq[j + 1])
            flow_calc.get_flow(i1, i2, flow_seq[j])

        # Built only after the flows above exist, since the temporal and
        # positional guides are constructed from ``flow_seq``.
        guides: List[BaseGuide] = [
            ColorGuide(input_seq),
            EdgeGuide(input_seq,
                      video_sequence.get_edge_sequence(i, is_forward)),
            TemporalGuide(key_img, output_seq, flow_seq,
                          video_sequence.get_temporal_sequence(i, is_forward)),
            PositionalGuide(flow_seq,
                            video_sequence.get_pos_sequence(i, is_forward))
        ]

        for j in range(interval):
            if j == 0:
                # The first output frame is the key frame itself.
                img = cv2.imread(key_img)
                cv2.imwrite(output_seq[0], img)
                continue

            parts = [f'{ebsynth_bin} -style {os.path.abspath(key_img)}']
            parts.extend(g.get_cmd(j, w) for g, w in zip(guides, weights))
            parts.append(f'-output {os.path.abspath(output_seq[j])}')
            parts.append('-searchvoteiters 12 -patchmatchiters 6')
            _run_ebsynth_cmd(' '.join(parts))
|
|
|
|
|
def process_sequences(i_arr, video_sequence: VideoSequence):
    """Worker entry point: run ``process_one_sequence`` for each index in ``i_arr``, in order."""
    for seq_idx in i_arr:
        process_one_sequence(seq_idx, video_sequence)
|
|
|
|
|
def run_ebsynth(video_sequence: VideoSequence):
    """Run ebsynth for every sequence using parallel ``spawn`` processes.

    Sequence indices ``0 .. n_seq - 1`` are split into at most
    ``MAX_PROCESS`` contiguous chunks whose sizes differ by at most one.
    Each chunk is handled by one process running ``process_sequences``.
    The function blocks until all workers exit, reports any worker that
    exited with a non-zero code, and prints the elapsed wall-clock time.
    """
    beg = time.time()

    # A local spawn context instead of ``mp.set_start_method('spawn')``:
    # the latter raises RuntimeError when the start method has already been
    # set, e.g. on a second call within the same interpreter.
    ctx = mp.get_context('spawn')

    n_seq = video_sequence.n_seq
    # Clamp to at least one worker so a non-positive MAX_PROCESS (e.g.
    # ``--n_proc 0``) cannot cause a division by zero; with no sequences
    # at all, no worker is started.
    n_process = min(max(MAX_PROCESS, 1), n_seq)

    processes = []
    if n_process > 0:
        cnt, remainder = divmod(n_seq, n_process)
        prev_idx = 0
        for worker_idx in range(n_process):
            # The first ``remainder`` workers take one extra sequence.
            task_cnt = cnt + 1 if worker_idx < remainder else cnt
            i_arr = list(range(prev_idx, prev_idx + task_cnt))
            prev_idx += task_cnt
            p = ctx.Process(target=process_sequences,
                            args=(i_arr, video_sequence))
            p.start()
            processes.append(p)

    for p in processes:
        p.join()
        if p.exitcode != 0:
            print(f'ebsynth worker {p.name} exited with code {p.exitcode}')

    end = time.time()

    print(f'ebsynth: {end-beg}')
|
|
|
|
|
@njit
def assemble_min_error_img_loop(H, W, a, b, error_mask, out):
    """Fill ``out[:H, :W]`` in place, taking ``b`` where ``error_mask`` is non-zero and ``a`` elsewhere.

    If the arrays carry trailing axes (e.g. colour channels), each
    ``out[row, col]`` assignment copies the whole trailing slice.
    """
    for row in range(H):
        for col in range(W):
            if error_mask[row, col] != 0:
                out[row, col] = b[row, col]
            else:
                out[row, col] = a[row, col]
|
|
|
|
|
def assemble_min_error_img(a, b, error_mask):
    """Compose a new image from ``a`` (where ``error_mask`` is 0) and ``b`` (elsewhere).

    The result has the shape and dtype of ``a``; only its first two axes
    are walked, per pixel.
    """
    out = np.empty_like(a)
    height, width = a.shape[:2]
    assemble_min_error_img_loop(height, width, a, b, error_mask, out)
    return out
|
|
|
|
|
def load_error(bin_path, img_shape):
    """Load a per-pixel error map from a binary ``.bin`` file.

    Layout: a native-endian int64 element count, followed by exactly that
    many native float32 values. The count must equal
    ``img_shape[0] * img_shape[1]``.

    Args:
        bin_path: Path of the error file.
        img_shape: Image shape; only the first two entries are used.

    Returns:
        A writable float32 array of shape ``(img_shape[0], img_shape[1])``.

    Raises:
        ValueError: If the file is too short for its header, or if the
            stored count or payload size does not match the image size.
    """
    height, width = img_shape[0], img_shape[1]
    img_size = height * width
    header = struct.Struct('q')

    with open(bin_path, 'rb') as fp:
        data = fp.read()

    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if len(data) < header.size:
        raise ValueError(
            f'{bin_path}: {len(data)} bytes is too short for the header')
    (stored_size,) = header.unpack_from(data, 0)
    if stored_size != img_size:
        raise ValueError(f'{bin_path}: stores {stored_size} values, '
                         f'expected {img_size} for shape {img_shape}')
    payload_size = len(data) - header.size
    expected_size = img_size * np.dtype(np.float32).itemsize
    if payload_size != expected_size:
        raise ValueError(f'{bin_path}: payload is {payload_size} bytes, '
                         f'expected {expected_size}')

    # Decode in C rather than via a Python tuple of img_size floats.
    # frombuffer returns a read-only view of ``data``; copy() keeps the
    # result writable.
    res = np.frombuffer(data, dtype=np.float32, count=img_size,
                        offset=header.size)
    return res.reshape(height, width).copy()
|
|
|
|
|
def process_seq(video_sequence: VideoSequence,
                i,
                blend_histogram=True,
                blend_gradient=True):
    """Blend the forward and backward ebsynth outputs of sequence ``i``.

    The first frame of the sequence is the key image of sequence ``i`` and
    is written unchanged. For every later frame, the forward output ``oa``
    (ebsynth run from key ``i``) and the backward output ``ob`` (run from
    key ``i + 1``) are combined:

    1. ``g_error_mask`` picks, per pixel, the output with the lower weighted
       ebsynth error (0 -> ``oa``, 1 -> ``ob``). ``weight1`` grows linearly
       from ``lb`` towards (but never reaching) ``ub`` across the sequence,
       and ``weight2 = 1 - weight1``.
    2. From the second blended frame on, the previous mask is warped along
       the optical flow and OR-ed in. Pixels that chose ``ob`` in the
       previous frame therefore keep choosing it.
    3. Colours are merged with ``histogram_blend.blend`` (or a plain
       weighted sum) and optionally refined with ``poisson_fusion``.

    Args:
        video_sequence: Supplies all input/output/flow/blend file paths.
        i: Index of the sequence to blend.
        blend_histogram: Use histogram blending instead of a weighted sum.
        blend_gradient: Apply ``poisson_fusion`` to the blended result.
    """
    key1_img = cv2.imread(video_sequence.get_key_img(i))
    img_shape = key1_img.shape
    interval = video_sequence.interval(i)
    beg_id = video_sequence.get_sequence_beg_id(i)

    oa_paths = video_sequence.get_output_sequence(i)
    ob_paths = video_sequence.get_output_sequence(i, False)
    # Slot 0 of the backward sequence holds key ``i + 1``. The remaining
    # slots are reversed so index k lines up with ``oa_paths[k]`` (this
    # assumes the backward pass stores them in reverse time order).
    ob_paths = [ob_paths[0]] + list(reversed(ob_paths[1:]))

    # Error maps are derived from the output paths AFTER the reordering
    # above, so the backward error map at index k belongs to the same frame
    # as the backward image at index k. Deriving them before the reordering
    # paired frame k's image with another frame's error map.
    # NOTE(review): ``replace`` rewrites every 'jpg' substring in the path,
    # not only the extension -- kept as-is to match the .bin naming used
    # elsewhere.
    binas = [x.replace('jpg', 'bin') for x in oa_paths]
    binbs = [x.replace('jpg', 'bin') for x in ob_paths]

    oas = [cv2.imread(x) for x in oa_paths]
    obs = [cv2.imread(x) for x in ob_paths]
    inputs = [cv2.imread(x) for x in video_sequence.get_input_sequence(i)]
    flow_seq = video_sequence.get_flow_sequence(i)

    # Load every error map before any blended frame is written.
    dist1s = []
    dist2s = []
    for k in range(interval - 1):
        dist1s.append(load_error(binas[k + 1], img_shape))
        dist2s.append(load_error(binbs[k + 1], img_shape))

    lb = 0
    ub = 1
    beg = time.time()
    p_mask = None

    # Frame 0 of the sequence is the key frame itself.
    cv2.imwrite(video_sequence.get_blending_img(beg_id), key1_img)

    # ``k`` (not ``i``) so the loop does not clobber the sequence index.
    for k in range(interval - 1):
        blend_out_path = video_sequence.get_blending_img(beg_id + k + 1)
        oa = oas[k + 1]
        ob = obs[k + 1]

        weight1 = k / (interval - 1) * (ub - lb) + lb
        weight2 = 1 - weight1
        mask = g_error_mask(dist1s[k], dist2s[k], weight1, weight2)
        if p_mask is not None:
            flow = flow_calc.get_flow(inputs[k], inputs[k + 1], flow_seq[k])
            p_mask = flow_calc.warp(p_mask, flow, 'nearest')
            mask = p_mask | mask
        p_mask = mask

        min_error_img = assemble_min_error_img(oa, ob, mask)
        if blend_histogram:
            hb_res = histogram_blend.blend(oa, ob, min_error_img,
                                           (1 - weight1), (1 - weight2))
        else:
            tmpa = oa.astype(np.float32)
            tmpb = ob.astype(np.float32)
            hb_res = (1 - weight1) * tmpa + (1 - weight2) * tmpb

        if blend_gradient:
            res = poisson_fusion(hb_res, oa, ob, mask)
        else:
            res = hb_res

        cv2.imwrite(blend_out_path, res)

    end = time.time()
    print('others:', end - beg)
|
|
|
|
|
def main(args):
    """Command-line entry point.

    Sets the module-level ``MAX_PROCESS`` from ``--n_proc`` and runs ebsynth
    unless ``-ne`` is given. It then blends every sequence with histogram
    blending (plus Poisson fusion when ``-ps`` is given) and exports a video
    when ``--output`` is set. Finally it calls
    ``video_sequence.remove_out_and_tmp()`` unless ``-tmp`` is given.
    """
    global MAX_PROCESS
    MAX_PROCESS = args.n_proc

    video_sequence = create_sequence(f'{args.name}', args.key_ind, args.key)

    if not args.ne:
        run_ebsynth(video_sequence)

    for seq_idx in range(video_sequence.n_seq):
        process_seq(video_sequence,
                    seq_idx,
                    blend_histogram=True,
                    blend_gradient=args.ps)

    if args.output:
        frame_to_video(args.output, video_sequence.blending_dir, args.fps,
                       False)

    if not args.tmp:
        video_sequence.remove_out_and_tmp()
|
|
|
|
|
if __name__ == '__main__':
    # Command-line interface; the parsed namespace is handed to ``main``.
    parser = argparse.ArgumentParser()
    parser.add_argument('name', type=str, help='Path to input video')
    parser.add_argument('--output', type=str, default=None,
                        help='Path to output video')
    parser.add_argument('--fps', type=float, default=30,
                        help='The FPS of output video')
    parser.add_argument('--key_ind', type=int, nargs='+', default=[1],
                        help='key frame index')
    parser.add_argument('--key', type=str, default='keys0',
                        help='The subfolder name of stylized key frames')
    parser.add_argument('--n_proc', type=int, default=8,
                        help='The max process count')
    parser.add_argument('-ps', action='store_true',
                        help='Use poisson gradient blending')
    parser.add_argument('-ne', action='store_true',
                        help='Do not run ebsynth (use previous ebsynth output)')
    parser.add_argument('-tmp', action='store_true',
                        help='Keep temporary output')

    args = parser.parse_args()
    main(args)
|
|