import os
import uuid
import glob
import shutil
from pathlib import Path
from multiprocessing.pool import Pool

import gradio as gr
import torch
from torchvision import transforms

import cv2
import numpy as np
from PIL import Image
import tqdm

from modules.networks.faceshifter import FSGenerator
from inference.alignment import norm_crop, norm_crop_with_M, paste_back
from inference.utils import save, get_5_from_98, get_detector, get_lmk
from third_party.PIPNet.lib.tools import get_lmk_model, demo_image
from inference.landmark_smooth import kalman_filter_landmark, savgol_filter_landmark
from inference.tricks import Trick

make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))


fs_model_name = 'faceshifter'
in_size = 256

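# MouthNet supplies an extra mouth-focused identity embedding alongside ArcFace;
# `crop_param` is presumably the (x1, y1, x2, y2) mouth crop on the 112x112
# ArcFace-aligned face, matching the numbers in the weight filename.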
mouth_net_param = {
    "use": True,
    "feature_dim": 128,
    "crop_param": (28, 56, 84, 112),
    "weight_path": make_abs_path("./weights/arcface/mouth_net_28_56_84_112.pth"),
}
trick = Trick()

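# Map HWC uint8 images in [0, 255] to CHW float tensors in [-1, 1].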
T = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ]
    )
tensor2pil_transform = transforms.ToPILImage()


def extract_generator(ckpt: str, pt: str):
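    """Load a Lightning checkpoint and save only the generator weights to `pt`."""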
    print('[extract_generator] loading ckpt...')
    from trainer.faceshifter.faceshifter_pl import FaceshifterPL512, FaceshifterPL
    import yaml
    with open(make_abs_path('../../trainer/faceshifter/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param

    if in_size == 256:
        net = FaceshifterPL(n_layers=3, num_D=3, config=config)
    elif in_size == 512:
        net = FaceshifterPL512(n_layers=3, num_D=3, config=config, verbose=False)
    else:
        raise ValueError('Not supported in_size.')
    checkpoint = torch.load(ckpt, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()

    G = net.generator
    torch.save(G.state_dict(), pt)
    print(f'[extract_generator] extracted from {ckpt}, pth saved to {pt}')


''' load model '''
if fs_model_name == 'faceshifter':
    pt_path = make_abs_path("./weights/extracted/G_mouth1_t38_post.pth")
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_6.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_6/epoch=3-step=128999.ckpt"
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_4.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_4/epoch=2-step=185999.ckpt"
    ckpt_path = None  # set to a training checkpoint (see the commented examples above) to re-extract
    if not os.path.exists(pt_path) or 't512' in pt_path:
        assert ckpt_path is not None, '`ckpt_path` must point to a training checkpoint to extract the generator.'
        extract_generator(ckpt_path, pt_path)
    fs_model = FSGenerator(
        make_abs_path("./weights/arcface/ms1mv3_arcface_r100_fp16/backbone.pth"),
        mouth_net_param=mouth_net_param,
        in_size=in_size,
        downup=in_size == 512,
    )
    fs_model.load_state_dict(torch.load(pt_path, "cpu"), strict=True)
    fs_model.eval()

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        i_r = fs_model(i_s, i_t)[0]  # x, id_vector, att

        if post:
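            # Blend face-parsing labels 0 and 17 (typically background and hair)
            # from the target back over the swapped face, then retouch the mouth.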
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1.) + 1.) * i_r  # torch 1.12.0
            i_r = trick.finetune_mouth(i_s, i_t, i_r) if in_size == 256 else i_r

        img_r = trick.tensor_to_arr(i_r)[0]
        return img_r

elif fs_model_name in ('simswap_triplet', 'simswap_vanilla'):
    from modules.networks.simswap import Generator_Adain_Upsample
    sw_model = Generator_Adain_Upsample(
        input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False,
        mouth_net_param=mouth_net_param
    )
    if fs_model_name == 'simswap_triplet':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_st5.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_triplet_5/epoch=12-step=782999.ckpt")
    elif fs_model_name == 'simswap_vanilla':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_tmp_sv4_off.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_vanilla_4/epoch=694-step=1487999.ckpt")
    else:
        pt_path = None
        ckpt_path = None
    sw_model.load_state_dict(torch.load(pt_path, "cpu"), strict=False)
    sw_model.eval()
    fs_model = sw_model

    from trainer.simswap.simswap_pl import SimSwapPL
    import yaml
    with open(make_abs_path('../../trainer/simswap/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param
    net = SimSwapPL(config=config, use_official_arc='off' in pt_path)

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()
    sw_mouth_net = net.mouth_net  # maybe None
    sw_netArc = net.netArc
    fs_model = fs_model.cuda()
    sw_mouth_net = sw_mouth_net.cuda() if sw_mouth_net is not None else sw_mouth_net
    sw_netArc = sw_netArc.cuda()

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        i_r = fs_model(source=i_s, target=i_t, net_arc=sw_netArc, mouth_net=sw_mouth_net,)
        if post:
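            # Same background/hair preservation as in the FaceShifter branch.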
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
        i_r = i_r.clamp(-1, 1)
        i_r = trick.tensor_to_arr(i_r)[0]
        return i_r

elif fs_model_name == 'simswap_official':
    from simswap.image_infer import SimSwapOfficialImageInfer
    fs_model = SimSwapOfficialImageInfer()
    pt_path = 'Simswap Official'
    mouth_net_param = {
        "use": False
    }

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        # Accept `post` for interface parity with the other branches; the
        # official pipeline applies no post-processing here.
        i_r = fs_model.image_infer(source_tensor=i_s, target_tensor=i_t)
        i_r = i_r.clamp(-1, 1)
        return trick.tensor_to_arr(i_r)[0]  # match the array output of the other branches

else:
    raise ValueError('Not supported fs_model_name.')


print(f'[demo] model loaded from {pt_path}')


def swap_image(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    paste_back=True,
    use_post=False,
    use_gpen=False,
    in_size=256,
):
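    """Swap the face from `source_image` onto the image at `target_path`.

    Both faces are detected and aligned to in_size x in_size crops; the swapped
    crop is saved to `out_path` and also pasted back into the original frame.
    """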
    name = "out_" + os.path.basename(target_path)
    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()

    device = torch.device(0) if gpu_mode else torch.device('cpu')
    source_img = np.array(Image.open(source_image).convert("RGB"))
    net, detector = get_lmk_model()
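    # PIPNet predicts 98 landmarks; reduce them to the 5-point template used for alignment.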
    lmk = get_5_from_98(demo_image(source_img, net, detector, device=device)[0])
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)

    target = np.array(Image.open(target_path).convert("RGB"))
    original_target = target.copy()
    lmk = get_5_from_98(demo_image(target, net, detector, device=device)[0])
    target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
    target = transform(target).unsqueeze(0)
    if gpu_mode:
        target = target.cuda()
        source_img = source_img.cuda()

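    # Debug dumps of the aligned crops (RGB flipped to BGR for OpenCV).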
    cv2.imwrite('cropped_source.png', trick.tensor_to_arr(source_img)[0, :, :, ::-1])
    cv2.imwrite('cropped_target.png', trick.tensor_to_arr(target)[0, :, :, ::-1])

    # both inputs are aligned (1, 3, in_size, in_size) crops in [-1, 1]
    result = infer_batch_to_img(source_img, target, post=use_post)

    cv2.imwrite('result.png', result[:, :, ::-1])

    os.makedirs(out_path, exist_ok=True)
    Image.fromarray(result.astype(np.uint8)).save(os.path.join(out_path, name))
    save((result, M, original_target, os.path.join(out_path, "paste_back_" + name), None),
         trick=trick, use_gpen=use_gpen)


def process_video(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    frames=9999999,
    use_tddfav2=False,
    landmark_smooth="kalman",
    use_gpen=False,
):
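    """Swap `source_image` onto every frame of the video (or frame directory) at `target_path`.

    Returns the swapped crops, facial paste masks, alignment matrices,
    original frames, output file names, and the source video's fps.
    """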
    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()
    device = torch.device(0) if gpu_mode else torch.device('cpu')
    ''' Target video to frames (.png) '''
    fps = 25.0
    if not os.path.isdir(target_path):
        vidcap = cv2.VideoCapture(target_path)
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        try:
            for match in glob.glob(os.path.join("./tmp/", "*.png")):
                os.remove(match)
            for match in glob.glob(os.path.join(out_path, "*.png")):
                os.remove(match)
        except Exception as e:
            print(e)
        os.makedirs("./tmp/", exist_ok=True)
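        # Decode the target video into per-frame PNGs under ./tmp/.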
        os.system(
            f"ffmpeg -i {target_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0  ./tmp/frame_%05d.png"
        )
        target_path = "./tmp/"
    globbed_images = sorted(glob.glob(os.path.join(target_path, "*.png")))
    ''' Get target landmarks '''
    print('[Extracting target landmarks...]')
    if not use_tddfav2:
        align_net, align_detector = get_lmk_model()
    else:
        align_net, align_detector = get_detector(gpu_mode=gpu_mode)
    target_lmks = []
    for frame_path in tqdm.tqdm(globbed_images):
        target = np.array(Image.open(frame_path).convert("RGB"))
        lmk = demo_image(target, align_net, align_detector, device=device)
        lmk = lmk[0]
        target_lmks.append(lmk)
    ''' Landmark smoothing '''
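    # Smoothing the landmark trajectories suppresses frame-to-frame jitter in the aligned crops.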
    target_lmks = np.array(target_lmks, np.float32)  # (#frames, 98, 2)
    if landmark_smooth == 'kalman':
        target_lmks = kalman_filter_landmark(target_lmks,
                                             process_noise=0.01,
                                             measure_noise=0.01).astype(np.int32)
    elif landmark_smooth == 'savgol':
        target_lmks = savgol_filter_landmark(target_lmks).astype(np.int32)
    elif landmark_smooth == 'cancel':
        target_lmks = target_lmks.astype(np.int32)
    else:
        raise ValueError('Not supported landmark_smooth choice.')
    ''' Crop source image '''
    source_img = np.array(Image.open(source_image).convert("RGB"))
    if not use_tddfav2:
        lmk = get_5_from_98(demo_image(source_img, align_net, align_detector, device=device)[0])
    else:
        lmk = get_lmk(source_img, align_net, align_detector)
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)
    if gpu_mode:
        source_img = source_img.cuda()
    ''' Process by frames '''
    targets = []
    t_facial_masks = []
    Ms = []
    original_frames = []
    names = []
    count = 0
    for image in tqdm.tqdm(globbed_images):
        names.append(os.path.join(out_path, Path(image).name))
        target = np.array(Image.open(image).convert("RGB"))
        original_frames.append(target)
        ''' Crop target frames '''
        lmk = get_5_from_98(target_lmks[count])
        target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
        target = transform(target).unsqueeze(0)  # in [-1,1]
        if gpu_mode:
            target = target.cuda()
        ''' Finetune paste masks '''
        target_facial_mask = trick.get_any_mask(target,
                                                par=[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]).squeeze()  # in [0,1]
        target_facial_mask = target_facial_mask.cpu().numpy().astype(np.float32)
        target_facial_mask = trick.finetune_mask(target_facial_mask, target_lmks)  # in [0,1]
        t_facial_masks.append(target_facial_mask)
        ''' Face swapping '''
        with torch.no_grad():
            if 'faceshifter' in fs_model_name:
                output = fs_model(source_img, target)[0]
                target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                target_hair_mask = trick.smooth_mask(target_hair_mask)
                # print(output.shape, target.shape, target_hair_mask.shape)
                output = target_hair_mask * target + (target_hair_mask * (-1.) + 1.) * output
                output = trick.finetune_mouth(source_img, target, output)
            elif 'simswap' in fs_model_name and 'official' not in fs_model_name:
                output = fs_model(source=source_img, target=target,
                                  net_arc=sw_netArc, mouth_net=sw_mouth_net,)
                if 'vanilla' not in fs_model_name:
                    target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                    target_hair_mask = trick.smooth_mask(target_hair_mask)
                    output = target_hair_mask * target + (target_hair_mask * (-1.) + 1.) * output
                    output = trick.finetune_mouth(source_img, target, output)
                output = output.clamp(-1, 1)
            elif 'simswap_official' in fs_model_name:
                output = fs_model.image_infer(source_tensor=source_img, target_tensor=target)
                output = output.clamp(-1, 1)
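            # Some generators return a tuple (image, ...); keep the image and
            # map it from [-1, 1] back to [0, 1] for ToPILImage.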
            if isinstance(output, tuple):
                target = output[0][0] * 0.5 + 0.5
            else:
                target = output[0] * 0.5 + 0.5
        targets.append(trick.gpen(np.array(tensor2pil_transform(target)), use_gpen=use_gpen))
        Ms.append(M)
        count += 1
        if count > frames:
            break
    os.makedirs(out_path, exist_ok=True)
    return targets, t_facial_masks, Ms, original_frames, names, fps


def swap_image_gr(img1, img2, use_post=False, use_gpen=False):
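    """Gradio callback: swap the face in `img1` onto `img2`, return the pasted-back image."""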
    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid1().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    target_path = os.path.join(data_dir, "target.png")
    filename = "paste_back_out_target.png"
    out_path = os.path.join(data_dir, filename)
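    # Gradio supplies RGB arrays; flip channels to BGR for cv2.imwrite.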
    cv2.imwrite(source_path, img1[:, :, ::-1])
    cv2.imwrite(target_path, img2[:, :, ::-1])
    swap_image(
        source_path,
        target_path,
        data_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        align_target='ffhq',
        align_source='ffhq',
        use_post=use_post,
        use_gpen=use_gpen,
        in_size=in_size,
    )
    out = cv2.imread(out_path)[..., ::-1]
    return out


def swap_video_gr(img1, target_path, use_gpen=False, frames=9999999):
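    """Gradio callback: swap the face in `img1` onto each frame of `target_path`, return the output video path."""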
    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid1().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    cv2.imwrite(source_path, img1[:, :, ::-1])
    out_dir = os.path.join(data_dir, "out")
    out_name = "output.mp4"
    targets, t_facial_masks, Ms, original_frames, names, fps = process_video(
        source_path,
        target_path,
        out_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        frames=frames,
        align_target='ffhq',
        align_source='ffhq',
        use_tddfav2=False,
        use_gpen=use_gpen,
    )

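    # `pool_process` > 1 pastes frames back in parallel; `audio` muxes the target's
    # audio track into the output; `concat` writes a side-by-side comparison video.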
    pool_process = 170
    audio = True
    concat = False

    if pool_process <= 1:
        for target, M, original_target, name, t_facial_mask in tqdm.tqdm(
                zip(targets, Ms, original_frames, names, t_facial_masks)
        ):
            if M is None or target is None:
                Image.fromarray(original_target.astype(np.uint8)).save(name)
                continue
            Image.fromarray(paste_back(np.array(target), M, original_target, t_facial_mask)).save(name)
    else:
        with Pool(pool_process) as pool:
            pool.map(save, zip(targets, Ms, original_frames, names, t_facial_masks))

    video_save_path = os.path.join(out_dir, out_name)
    if audio:
        print("use audio")
        os.system(
            f"ffmpeg  -y -r {fps} -i {out_dir}/frame_%05d.png -i {target_path}"
            f" -map 0:v:0 -map 1:a:0? -c:a copy -c:v libx264 -r {fps} -crf 10 -pix_fmt yuv420p  {video_save_path}"
        )
    else:
        print("no audio")
        os.system(
            f"ffmpeg  -y -r {fps} -i ./tmp/frame_%05d.png "
            f"-c:v libx264 -r {fps} -crf 10 -pix_fmt yuv420p {video_save_path}"
        )
    # ffmpeg -i left.mp4 -i right.mp4 -filter_complex hstack output.mp4
    if concat:
        concat_video_save_path = os.path.join(out_dir, "concat_" + out_name)
        os.system(
            f"ffmpeg -y  -i {target_path}  -i {video_save_path} -filter_complex hstack {concat_video_save_path}"
        )
    # delete tmp file
    shutil.rmtree("./tmp/", ignore_errors=True)  # tmp may not exist when target_path was already a frame directory
    for match in glob.glob(os.path.join(out_dir, "*.png")):
        os.remove(match)
    print(video_save_path)
    return video_save_path


css = """
#warning {background-color: #FFCCCB} 
.feedback textarea {font-size: 24px !important}
.run_button {background-color: orange} 
"""


if __name__ == "__main__":
    use_gpu = torch.cuda.is_available()

    with gr.Blocks(css=css) as demo:
        gr.Markdown("SuperSwap")

        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=3):
                    image1_input = gr.Image(label='source')
                    image2_input = gr.Image(label='target')
                    image_use_post = gr.Checkbox(label="Post-Process")
                    image_use_gpen = gr.Checkbox(label="Super Resolution")
                with gr.Column(scale=2):
                    image_output = gr.Image()
                    image_button = gr.Button("Run: Face Swapping", elem_classes="run_button")
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=3):
                    image3_input = gr.Image(label='source')
                    video_input = gr.Video(label='target')
                with gr.Column(scale=2):
                    video_output = gr.Video()
                    video_use_gpen = gr.Checkbox(label="Super Resolution")
                    video_button = gr.Button("Run: Face Swapping", elem_classes="run_button")
        image_button.click(
            swap_image_gr,
            inputs=[image1_input, image2_input, image_use_post, image_use_gpen],
            outputs=image_output,
        )
        video_button.click(
            swap_video_gr,
            inputs=[image3_input, video_input, video_use_gpen],
            outputs=video_output,
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)