"""Gradio demo: portrait matting, background replacement and harmonization.

The app extracts the foreground of a source portrait with ``Matting``,
composites it onto a user-supplied background, and can then harmonize that
composite with ``Inference``.
"""
import gradio as gr
from tools import Inference, Matting, log
from omegaconf import OmegaConf
import os
import sys
import numpy as np
import torchvision.transforms.functional as tf
from PIL import Image

args = OmegaConf.load("./config/test.yaml")

# Module-level placeholders; nothing in this script reads or writes them.
global_comp = None
global_mask = None

log("Model loading")
phnet = Inference(**args)
stylematte = Matting(**args)
log("Model loaded")


def harmonize(comp, mask):
    """Harmonize a composite so its foreground matches the new background.

    Args:
        comp: Composite PIL image (value of the "Composite" component).
        mask: PIL image of the matte (value of the "Alpha matte" component).

    Returns:
        A uint8 numpy array of shape (H, W, 3) at the composite's original
        size, or a 16x16 black placeholder when either input is missing.
    """
    log("Inference started")
    if comp is None or mask is None:
        log("Empty source")
        return np.zeros((16, 16, 3), dtype=np.uint8)

    comp = comp.convert('RGB')
    # Convert the matte to a 1-bit (binary) image before feeding the network.
    mask = mask.convert('1')
    # PIL reports size as (W, H); torchvision's resize expects (H, W).
    in_shape = comp.size[::-1]

    net_size = [args.image_size, args.image_size]
    compt = tf.to_tensor(tf.resize(comp, net_size))
    maskt = tf.to_tensor(tf.resize(mask, net_size))

    res = phnet.harmonize(compt, maskt)
    # Scale the result back to the caller's original resolution.
    res = tf.resize(res, in_shape)
    log("Inference finished")
    # Assumes res is a batched (1, 3, H, W) tensor in [0, 1] — TODO confirm.
    # Clamp so values slightly outside [0, 1] do not wrap around in the
    # uint8 cast.
    return np.uint8((res.clamp(0, 1) * 255)[0].permute(1, 2, 0).numpy())


def extract_matte(img, back):
    """Extract the foreground of ``img`` and composite it onto ``back``.

    Args:
        img: Source portrait as a numpy array (value of the source component).
        back: New background as a PIL image.

    Returns:
        ``[composite, matte, foreground]`` for the Composite, Alpha matte and
        Extracted foreground components, or ``[None, None, None]`` when
        either input is missing.
    """
    if img is None or back is None:
        log("Empty source")
        return [None, None, None]

    mask, fg = stylematte.extract(img)
    fg_pil = Image.fromarray(np.uint8(fg))

    # Force 3 channels so RGBA / greyscale backgrounds broadcast against the
    # foreground (assumes fg is (H, W, 3) — TODO confirm). PIL sizes are
    # (W, H), hence the reversed mask shape.
    background = np.array(back.convert('RGB').resize(mask.shape[::-1]))
    composite = fg + (1 - mask[:, :, None]) * background
    # Clip before the uint8 cast so out-of-range values do not wrap.
    composite_pil = Image.fromarray(np.uint8(np.clip(composite, 0, 255)))
    return [composite_pil, mask, fg_pil]


def css(height=3, scale=2):
    """Return a CSS rule sizing ``.output_image`` elements in rem units."""
    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"


with gr.Blocks() as demo:
    gr.Markdown(
        """
# Welcome to portrait transfer demo app!
Select source portrait image and new background.
""")
    btn_compose = gr.Button(value="Compose")

    with gr.Row():
        input_ui = gr.Image(
            type="numpy", label='Source image to extract foreground')
        back_ui = gr.Image(type="pil", label='The new background')

    gr.Examples(
        examples=[["./assets/comp.jpg", "./assets/back.jpg"]],
        inputs=[input_ui, back_ui],
    )

    gr.Markdown(
        """
## Resulting alpha matte and extracted foreground.
""")
    with gr.Row():
        matte_ui = gr.Image(type="pil", label='Alpha matte')
        fg_ui = gr.Image(type="pil", image_mode='RGBA',
                         label='Extracted foreground')

    gr.Markdown(
        """
## Click the button and compare the composite with the harmonized version.
""")
    btn_harmonize = gr.Button(value="Harmonize composite")

    with gr.Row():
        composite_ui = gr.Image(type="pil", label='Composite')
        # `css` is not a gr.Image parameter (custom CSS is applied through
        # gr.Blocks(css=...)), so it is not passed here.
        harmonized_ui = gr.Image(
            type="pil", label='Harmonized composite')

    btn_compose.click(extract_matte, inputs=[input_ui, back_ui], outputs=[
        composite_ui, matte_ui, fg_ui])
    btn_harmonize.click(harmonize, inputs=[
        composite_ui, matte_ui], outputs=[harmonized_ui])

log("Interface created")
demo.launch(share=True)