import warnings

import gradio
import imageio
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage import img_as_ubyte
from skimage.transform import resize

from demo import load_checkpoints, make_animation


def inference(source_image_path='./assets/source.png',
              driving_video_path='./assets/driving.mp4',
              dataset_name='vox'):
    """Animate a source image with a driving video and return the result path.

    dataset_name selects the pretrained model: one of 'vox', 'taichi', 'ted', 'mgif'.
    """
    # Run on CPU so the demo works without a GPU.
    device = torch.device('cpu')

    output_video_path = './generated.mp4'
    # vox, taichi and mgif models work at 256x256; the ted model works at 384x384.
    pixel = 384 if dataset_name == 'ted' else 256
    config_path = f'config/{dataset_name}-{pixel}.yaml'
    checkpoint_path = f'checkpoints/{dataset_name}.pth.tar'
    predict_mode = 'relative'  # one of 'standard', 'relative', 'avd'

    warnings.filterwarnings("ignore")

    # Load and resize the source image, dropping any alpha channel.
    source_image = imageio.imread(source_image_path)
    source_image = resize(source_image, (pixel, pixel))[..., :3]

    # Read the driving video frame by frame; imageio raises RuntimeError on a
    # truncated stream, in which case we keep the frames read so far.
    reader = imageio.get_reader(driving_video_path)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]
    # driving_video = driving_video[:10]  # uncomment to test quickly on a short clip

    def display(source, driving, generated=None) -> animation.ArtistAnimation:
        """Build a side-by-side animation of source | driving | generated frames."""
        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))

        ims = []
        for i in range(len(driving)):
            cols = [source, driving[i]]
            if generated is not None:
                cols.append(generated[i])
            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
            plt.axis('off')
            ims.append([im])

        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
        plt.close()
        return ani

    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
        config_path=config_path, checkpoint_path=checkpoint_path, device=device)

    predictions = make_animation(source_image, driving_video, inpainting, kp_detector,
                                 dense_motion_network, avd_network,
                                 device=device, mode=predict_mode)

    # Save the generated frames on their own.
    imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)

    # Save the side-by-side comparison; mp4 output needs the ffmpeg writer
    # (the imagemagick writer is meant for GIFs).
    ani = display(source_image, driving_video, predictions)
    ani.save('animation.mp4', writer='ffmpeg', fps=60)
    return 'animation.mp4'


# Gradio 3+ component API (the old gradio.inputs namespace has been removed).
demo = gradio.Interface(
    fn=inference,
    inputs=[
        gradio.Image(type="filepath", label="Input image"),
        gradio.Video(label="Input video"),
        gradio.Dropdown(choices=['vox', 'taichi', 'ted', 'mgif'], value="vox", label="Model"),
    ],
    outputs=["video"],
    examples=[
        ['./assets/source.png', './assets/driving.mp4', "vox"],
        ['./assets/source_ted.png', './assets/driving_ted.mp4', "ted"],
    ],
)

if __name__ == "__main__":
    demo.launch()
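
# Minimal usage sketch for running the pipeline without the web UI, assuming the
# repo's default assets and that the matching checkpoint has already been
# downloaded (e.g. checkpoints/vox.pth.tar):
#
#     result_path = inference(
#         source_image_path='./assets/source.png',
#         driving_video_path='./assets/driving.mp4',
#         dataset_name='vox',
#     )
#     print(f'Side-by-side animation written to {result_path}')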