# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..  
# *************************************************************************
import os
import argparse
import numpy as np
# torch
import torch
from ema_pytorch import EMA
from einops import rearrange
import cv2
# utils
from utils.utils import set_seed, count_param, print_peak_memory
# model
import imageio
from model_lib.ControlNet.cldm.model import create_model
import copy
import glob
import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
import face_alignment
import sys
from decord import VideoReader
from decord import cpu, gpu

# Major PyTorch version as a string (e.g. "2"); `main` compares it with "2"
# before calling torch.compile.
TORCH_VERSION = torch.__version__.partition(".")[0]
# Half-precision dtype handed to autocast during sampling in `visualize_mm`.
FP16_DTYPE = torch.float16
print(f"TORCH_VERSION={TORCH_VERSION} FP16_DTYPE={FP16_DTYPE}")

def extract_local_feature_from_single_img(img, fa, remove_local=False, real_tocrop=None, target_res = 512):
    """Crop (or blank out) the two eye regions and the mouth region of a face image.

    Landmarks are detected with ``fa`` on a 256x256 copy of ``img``. The mean
    landmark position of each region is clamped away from the border, scaled
    to ``target_res`` pixel coordinates, and a square of side
    ``2 * (target_res // 8)`` is taken around it.

    Args:
        img: (C, H, W) tensor; it is permuted to (H, W, C) for detection.
            Presumably float values in [0, 1] -- TODO confirm against callers.
        fa: a ``face_alignment.FaceAlignment`` instance.
        remove_local: if True, return the image with the three regions set
            to 0; otherwise return an all-zero image containing only the
            three regions.
        real_tocrop: optional (C, H, W) tensor to crop from instead of
            ``img``; landmarks are still detected on ``img``.
        target_res: side length (pixels) of the image being cropped.

    Returns:
        A single (C, H, W) tensor on ``img.device`` (never a tuple, so callers
        can ``torch.stack`` the results). When ``real_tocrop`` is None the
        values are mapped with ``x * 2 - 1``. If no face is detected, a tensor
        shaped like ``img`` is returned: all -1 when ``real_tocrop`` is None,
        all 0 otherwise.
    """
    device = img.device
    pred = img.permute([1, 2, 0]).detach().cpu().numpy()

    pred_lmks = img_as_ubyte(resize(pred, (256, 256)))

    try:
        detections = fa.get_landmarks_from_image(pred_lmks, return_landmark_score=False)
    except Exception:
        # Best effort: treat a detector failure like "no face found".
        detections = None
    if detections is None or len(detections) == 0:
        print ('undetected faces!!')
        blank = torch.zeros_like(img)
        return blank * 2 - 1. if real_tocrop is None else blank
    lmks = detections[0]

    halfedge = 32

    def _region_center(points):
        # Mean landmark (x, y), clamped so the crop stays inside the 256x256
        # detection frame, then scaled to target_res pixel coordinates.
        center = np.clip(np.round(np.mean(points, axis=0)), halfedge, 255 - halfedge)
        return (center * (target_res / 256)).astype(np.int32)

    # NOTE(review): assuming the standard 0-based 68-point layout (right eye
    # 36-41, left eye 42-47, mouth 48-67), each slice below starts one index
    # late. Kept unchanged so crops match previously produced results.
    left_eye_center = _region_center(lmks[43:48])
    right_eye_center = _region_center(lmks[37:42])
    mouth_center = _region_center(lmks[49:68])

    if real_tocrop is not None:
        pred = real_tocrop.permute([1, 2, 0]).detach().cpu().numpy()

    half_size = target_res // 8  # 64 at target_res=512
    # (row slice, column slice) of each square region; centers are (x, y).
    regions = [
        (slice(cy - half_size, cy + half_size), slice(cx - half_size, cx + half_size))
        for cx, cy in (left_eye_center, right_eye_center, mouth_center)
    ]

    if remove_local:
        # Work on a copy: for CPU tensors .cpu().numpy() shares memory with the
        # tensor, so writing into `pred` would modify the caller's input.
        local_viz = pred.copy()
        for rows, cols in regions:
            local_viz[rows, cols] = 0
    else:
        local_viz = np.zeros_like(pred)
        for rows, cols in regions:
            local_viz[rows, cols] = pred[rows, cols]

    local_viz = torch.from_numpy(local_viz).to(device).permute([2, 0, 1])
    if real_tocrop is None:
        local_viz = local_viz * 2 - 1.
    return local_viz

def find_best_frame_byheadpose_fa(source_image, driving_video, fa):
    """Return the index of the driving frame whose landmarks best match the source image.

    Every image is resized to 256x256 and its landmarks are detected with
    ``fa``. A frame's score is the sum of absolute differences between its
    landmark coordinates and the source's landmark coordinates. The lowest
    score wins.

    Returns 0 when no face is detected in the source image, since there is
    nothing to match against. Driving frames without a detected face are
    skipped. If no frame has a face, 0 is returned.
    """
    def _first_face_landmarks(image):
        # Landmarks of the first detected face, or None if detection failed.
        resized = img_as_ubyte(resize(image, (256, 256)))
        try:
            detections = fa.get_landmarks_from_image(resized, return_landmark_score=False)
        except Exception:
            return None
        if detections is None or len(detections) == 0:
            return None
        return np.asarray(detections[0])

    src_pose_array = _first_face_landmarks(source_image)
    if src_pose_array is None:
        print ('undetected faces in the source image!!')
        return 0

    min_diff = 1e8
    best_frame = 0
    for i, frame in enumerate(driving_video):
        drv_pose_array = _first_face_landmarks(frame)
        if drv_pose_array is None:
            # A frame without a face cannot be a pose match; don't score it.
            print ('undetected faces in the %d-th driving image!!' % i)
            continue
        diff = np.sum(np.abs(src_pose_array - drv_pose_array))
        if diff < min_diff:
            best_frame = i
            min_diff = diff
    return best_frame

def adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame=-1):
    """Align every driving frame to the face box of the source image.

    Landmarks of the source image and of one reference driving frame
    (``best_frame``) are detected at 256x256. An affine transform is built
    that maps the reference frame's landmark box onto the source's box. That
    transform is then applied to every driving frame at two resolutions.

    Args:
        source_image: source image array (H, W, C); only the first 3 channels
            are used for detection.
        driving_video: sequence of driving frames (H, W, C).
        fa: a ``face_alignment.FaceAlignment`` instance.
        nm_res: side length of the low-resolution output frames.
        nmd_res: side length of the high-resolution output frames.
        best_frame: index of the reference driving frame. -2 disables
            alignment (frames are only resized). Any other negative value
            picks the frame with ``find_best_frame_byheadpose_fa``.

    Returns:
        ``(frames_at_nm_res, frames_at_nmd_res)``, two lists of frames. If
        alignment is disabled or landmarks are not found, the frames are only
        resized.

    Raises:
        ValueError: if ``best_frame`` is not a valid index into ``driving_video``.
    """
    def _resized(res):
        return [resize(frame, (res, res)) for frame in driving_video]

    def _missing(detections):
        return detections is None or len(detections) == 0

    if best_frame == -2:
        return _resized(nm_res), _resized(nmd_res)
    if best_frame >= len(driving_video):
        raise ValueError(
            "please specify one frame in driving video of which the pose match best with the pose of source image"
        )

    src = img_as_ubyte(resize(source_image[..., :3], (256, 256)))
    if best_frame < 0:
        best_frame = find_best_frame_byheadpose_fa(src, driving_video, fa)

    print ('Best Frame: %d' % best_frame)
    driving = img_as_ubyte(resize(driving_video[best_frame], (256, 256)))

    src_lmks = fa.get_landmarks_from_image(src, return_landmark_score=False)
    drv_lmks = fa.get_landmarks_from_image(driving, return_landmark_score=False)
    if _missing(src_lmks) or _missing(drv_lmks):
        return _resized(nm_res), _resized(nmd_res)

    def _box_corners(lmks):
        # Corners (TL, TR, BL, BR) of a box centred on the landmark mean whose
        # half-extent is half the landmark bounding-box size.
        center = np.mean(lmks, axis=0)
        half = (np.max(lmks, axis=0) - np.min(lmks, axis=0)) * 0.5
        return np.array([
            [center[0] - half[0], center[1] - half[1]],
            [center[0] + half[0], center[1] - half[1]],
            [center[0] - half[0], center[1] + half[1]],
            [center[0] + half[0], center[1] + half[1]],
        ]).astype(np.float32)

    src_point = _box_corners(src_lmks[0])
    dst_point = _box_corners(drv_lmks[0])

    # Landmarks are in 256x256 coordinates. Rescale the three matched corners
    # to each output resolution so the warp is correct for any nm_res/nmd_res.
    # The transforms are the same for every frame, so compute them once.
    ld_scale = nm_res / 256
    hd_scale = nmd_res / 256
    ld_affine = cv2.getAffineTransform(dst_point[:3] * ld_scale, src_point[:3] * ld_scale)
    hd_affine = cv2.getAffineTransform(dst_point[:3] * hd_scale, src_point[:3] * hd_scale)

    adjusted_driving_video = []
    adjusted_driving_video_hd = []
    for frame in driving_video:
        frame_ld = resize(frame, (nm_res, nm_res))
        frame_hd = resize(frame, (nmd_res, nmd_res))
        adjusted_driving_video.append(cv2.warpAffine(frame_ld, ld_affine, (nm_res, nm_res)))
        adjusted_driving_video_hd.append(cv2.warpAffine(frame_hd, hd_affine, (nmd_res, nmd_res)))

    return adjusted_driving_video, adjusted_driving_video_hd

def _read_driving_frames(driving_video_path):
    """Return ``(frames, fps)`` for an .mp4 video, or for a single still image at 1 fps."""
    if '.mp4' not in driving_video_path:
        return [imageio.imread(driving_video_path)[..., :3]], 1
    reader = imageio.get_reader(driving_video_path)
    try:
        fps = reader.get_meta_data()['fps']
        frames = []
        try:
            for im in reader:
                frames.append(im)
        except RuntimeError:
            # Best effort: keep the frames decoded before the reader failed.
            pass
    finally:
        # Always release the underlying decoder, even if metadata reading fails.
        reader.close()
    return frames, fps


def _load_more_sources(more_source_image_pattern, nmd_res, device):
    """Load extra source images matching a glob as a (1, N, C, H, W) tensor, or None if none match."""
    more_sources_hd = []
    for more_source_path in sorted(glob.glob(more_source_image_pattern)):
        more_source_image = imageio.imread(more_source_path)
        more_source_image_hd = resize(more_source_image, (nmd_res, nmd_res))[..., :3]
        more_source_hd = torch.tensor(more_source_image_hd[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        more_sources_hd.append(more_source_hd.to(device))
    if not more_sources_hd:
        print(f'no image matches more_source_image_pattern={more_source_image_pattern!r}; ignoring it')
        return None
    return torch.stack(more_sources_hd, dim = 1)


def x_portrait_data_prep(source_image_path, driving_video_path, device, best_frame_id=0, start_idx = 0, num_frames=0, skip=1, output_local=False, more_source_image_pattern="", target_resolution = 512):
    """Load a source image and a driving video and build an inference batch.

    The driving frames are aligned to the source face with
    ``adjust_driving_video_to_src_image``. The selection is then
    ``frames[start_idx:end][::skip]``, where ``end`` covers ``num_frames``
    strided frames (or every frame when ``num_frames == 0``).

    Returns:
        dict with keys:
          'fps': frame rate of the driving video (1 for a still image).
          'sources': (F, 1, C, H, W) source image repeated per frame, in [-1, 1]
              (presumably; computed as ``x * 2 - 1`` on resized pixels).
          'conditions': (F, 1, C, H, W) aligned driving frames, mapped the same way.
          'local': (F, C, H, W) eye/mouth crops, only when ``output_local``.
          'more_sources': (F, N, C, H, W), only when ``more_source_image_pattern``
              matches at least one image.

    Raises:
        ValueError: if the frame selection is empty.
    """
    source_image = imageio.imread(source_image_path)
    driving_video, fps = _read_driving_frames(driving_video_path)

    nmd_res = target_resolution
    nm_res = 256
    source_image_hd = resize(source_image, (nmd_res, nmd_res))[..., :3]

    more_sources_hd = _load_more_sources(more_source_image_pattern, nmd_res, device) if more_source_image_pattern else None

    # face_alignment expects a plain 'cuda'/'cpu' device string.
    fa_device = 'cuda' if torch.device(device).type == 'cuda' else 'cpu'
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True, device=fa_device)

    driving_video, driving_video_hd = adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame_id)

    if num_frames == 0:
        end_idx = len(driving_video)
    else:
        num_frames = min(len(driving_video), num_frames)
        end_idx = start_idx + num_frames * skip

    driving_video = driving_video[start_idx:end_idx][::skip]
    driving_video_hd = driving_video_hd[start_idx:end_idx][::skip]
    num_frames = len(driving_video)
    if num_frames == 0:
        raise ValueError(
            f"no driving frames selected from {driving_video_path} (start_idx={start_idx}, skip={skip})"
        )

    with torch.no_grad():
        real_source_hd = torch.tensor(source_image_hd[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
        driving_hd = torch.tensor(np.array(driving_video_hd).astype(np.float32)).permute(0, 3, 1, 2).to(device)

        # One (1, C, H, W) condition per frame.
        raw_drivings = [driving_hd[i:i + 1] * 2 - 1. for i in range(num_frames)]
        local_features = []
        if output_local:
            for i in range(num_frames):
                local_feature_img = extract_local_feature_from_single_img(driving_hd[i], fa, target_res=nmd_res)
                if isinstance(local_feature_img, tuple):
                    # Some versions return (tensor, box) when no face is found.
                    local_feature_img = local_feature_img[0]
                local_features.append(local_feature_img)

    batch_data = {'fps': fps}
    real_source_hd = real_source_hd * 2 - 1
    batch_data['sources'] = real_source_hd[:, None, :, :, :].repeat([num_frames, 1, 1, 1, 1])
    if more_sources_hd is not None:
        more_sources_hd = more_sources_hd * 2 - 1
        batch_data['more_sources'] = more_sources_hd.repeat([num_frames, 1, 1, 1, 1])

    batch_data['conditions'] = torch.stack(raw_drivings, dim = 0)
    if output_local:
        batch_data['local'] = torch.stack(local_features, dim = 0)

    return batch_data

# Load model weights from a checkpoint file, optionally skipping the `control_model.input_hint_block.*` keys.
def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
    """Load weights from ``ckpt_path`` into ``model``.

    Checkpoints that nest their weights under a ``'state_dict'`` key are
    unwrapped first.

    Args:
        model: the ``torch.nn.Module`` to load into.
        ckpt_path: checkpoint path passed to ``torch.load``.
        reinit_hint_block: if True, drop every key starting with
            ``control_model.input_hint_block`` so those parameters keep the
            values they already have in ``model``.
        strict: forwarded to ``model.load_state_dict``.
        map_location: forwarded to ``torch.load``.

    Returns:
        The result of ``model.load_state_dict`` (missing/unexpected keys).
        When ``strict`` is False the mismatches are also printed, so a
        partially matching checkpoint is not silently accepted.
    """
    print(f"Loading model state dict from {ckpt_path} ...")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    state_dict = state_dict.get('state_dict', state_dict)
    if reinit_hint_block:
        print("Ignoring hint block parameters from checkpoint!")
        # Pop in place (rather than rebuilding the dict) to keep any
        # `_metadata` attribute that nn.Module.load_state_dict consults.
        for k in [k for k in state_dict if k.startswith("control_model.input_hint_block")]:
            state_dict.pop(k)
    result = model.load_state_dict(state_dict, strict=strict)
    if not strict:
        missing = list(getattr(result, "missing_keys", []) or [])
        unexpected = list(getattr(result, "unexpected_keys", []) or [])
        if missing:
            print(f"Missing keys ({len(missing)}): {missing[:10]}")
        if unexpected:
            print(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}")
    return result

def get_cond_control(args, batch_data, control_type, device, start, end, model=None, batch_size=None, train=True, key=0):
    """Build the conditioning inputs for frames ``start:end``.

    Only ``control_type == "appearance_pose_local_mm"`` is supported.

    Args:
        args, batch_size, train: accepted for interface compatibility; unused.
        batch_data: dict with 'sources' (indexed ``[frames, key]``),
            'conditions', 'local' and optionally 'more_sources'
            (indexed ``[frames, i]``).
        control_type: conditioning type name.
        device: device that every returned tensor is moved to.
        start, end: frame range to use.
        model: object providing ``encode_first_stage`` and
            ``get_first_stage_encoding``, used to encode images to latents.
        key: which entry along dim 1 of 'sources' to use.

    Returns:
        ``[appearance_latents, pose_conditions, local_conditions, more_latents]``,
        where ``more_latents`` holds one single-element list of latents per
        extra source image (empty if there are none).

    Raises:
        NotImplementedError: for any other ``control_type``.
    """
    if control_type != "appearance_pose_local_mm":
        raise NotImplementedError(f"cond_type={control_type} not supported!")

    vae_bs = 16  # encode in chunks, presumably to bound peak VAE memory

    def _encode(images):
        # Encode images to latents chunk by chunk and concatenate along dim 0.
        return torch.concat([
            model.get_first_stage_encoding(model.encode_first_stage(images[k:k + vae_bs]))
            for k in range(0, images.shape[0], vae_bs)
        ], dim=0)

    src = batch_data['sources'][start:end, key].to(device)
    c_cat_list = batch_data['conditions'][start:end].to(device)
    cond_img_cat = _encode(src).to(device)
    p_local = batch_data['local'][start:end].to(device)
    print ('Total frames:{}'.format(cond_img_cat.shape))

    more_cond_imgs = []
    if 'more_sources' in batch_data:
        for i in range(batch_data['more_sources'].shape[1]):
            m_cond_img = _encode(batch_data['more_sources'][start:end, i].to(device))
            more_cond_imgs.append([m_cond_img.to(device)])

    return [cond_img_cat, c_cat_list, p_local, more_cond_imgs]

def _encode_latents_chunked(infer_model, images, count, chunk):
    """Encode ``images[i:i + chunk]`` for i in ``range(0, count, chunk)`` and concatenate the latents."""
    latents = []
    for i in range(0, count, chunk):
        latents.append(infer_model.get_first_stage_encoding(infer_model.encode_first_stage(images[i:i + chunk])))
    return torch.cat(latents, dim=0)


def visualize_mm(args, name, batch_data, infer_model, nSample, local_image_dir, num_mix=4, preset_output_name=''):
    """Run DDIM sampling for ``nSample`` frames and write the generated frames as an .mp4.

    The initial noise comes from the source latents, or from
    ``args.initial_facevid2vid_results`` when given, diffused with
    ``q_sample`` at t=999. The output is written to ``local_image_dir``,
    named either ``preset_output_name`` (with an .mp4 extension) or a name
    built from ``name``, the control type, ``uc_scale``, the source and driving
    names and ``num_mix``.

    Args:
        args: parsed CLI arguments (uc_scale, control_type, control_mode,
            wonoise, use_fp16, device, num_drivings, ddim_steps, eta,
            initial_facevid2vid_results).
        name: prefix for the generated output file name.
        batch_data: dict from ``x_portrait_data_prep`` plus 'video_name' and
            'source_name'.
        infer_model: the diffusion model used for encoding, sampling and decoding.
        nSample: number of frames to generate.
        local_image_dir: output directory (created if missing).
        num_mix: forwarded to ``sample_log`` as ``num_overlap``.
        preset_output_name: optional explicit output file name.
    """
    driving_video_name = os.path.basename(batch_data['video_name']).split('.')[0]
    source_name = os.path.basename(batch_data['source_name']).split('.')[0]

    os.makedirs(local_image_dir, exist_ok=True)

    uc_scale = args.uc_scale
    if preset_output_name:
        preset_output_name = preset_output_name.split('.')[0]+'.mp4'
        output_path = f"{local_image_dir}/{preset_output_name}"
    else:
        output_path = f"{local_image_dir}/{name}_{args.control_type}_uc{uc_scale}_{source_name}_by_{driving_video_name}_mix{num_mix}.mp4"

    infer_model.eval()

    _, _, ch, h, w = batch_data['sources'].shape
    vae_bs = 16

    if args.initial_facevid2vid_results:
        # Initialise from a precomputed facevid2vid video, resized to the
        # source resolution (cv2.resize takes (width, height)).
        facevid2vid = []
        facevid2vid_results = VideoReader(args.initial_facevid2vid_results, ctx=cpu(0))
        for frame_id in range(len(facevid2vid_results)):
            frame = cv2.resize(facevid2vid_results[frame_id].asnumpy(), (w, h)) / 255
            facevid2vid.append(torch.from_numpy(frame * 2 - 1).permute(2, 0, 1))
        init_images = torch.stack(facevid2vid)[:nSample].float().to(args.device)
    else:
        init_images = batch_data['sources'][:nSample].reshape([-1, ch, h, w])
    pre_noise = _encode_latents_chunked(infer_model, init_images, nSample, vae_bs)
    # Diffuse to t=999, presumably the last step of a 1000-step schedule.
    pre_noise = infer_model.q_sample(x_start = pre_noise, t = torch.tensor([999]).to(pre_noise.device))

    text = [""] * nSample

    all_c_cat = get_cond_control(args, batch_data, args.control_type, args.device, start=0, end=nSample, model=infer_model, train=False)
    cond_img_cat = [all_c_cat[0]]
    pose_cond_list = [rearrange(all_c_cat[1], "b f c h w -> (b f) c h w")]
    local_pose_cond_list = [all_c_cat[2]]
    has_more_sources = len(all_c_cat) > 3 and len(all_c_cat[3]) > 0

    c_cross = infer_model.get_learned_conditioning(text)[:nSample]
    uc_cross = infer_model.get_unconditional_conditioning(nSample)

    # Conditional (c) and unconditional (uc) inputs for classifier-free guidance.
    c = {"c_crossattn": [c_cross], "image_control": cond_img_cat}
    if args.control_mode == "controlnet_important":
        uc = {"c_crossattn": [uc_cross]}
    else:
        uc = {"c_crossattn": [uc_cross], "image_control": cond_img_cat}
    if "appearance_pose" in args.control_type:
        c['c_concat'] = pose_cond_list
        uc['c_concat'] = [torch.zeros_like(pose_cond_list[0])]
    if "appearance_pose_local" in args.control_type:
        c["local_c_concat"] = local_pose_cond_list
        uc["local_c_concat"] = [torch.zeros_like(local_pose_cond_list[0])]
    if has_more_sources:
        c['more_image_control'] = all_c_cat[3]
        uc['more_image_control'] = all_c_cat[3]
    c['wonoise'] = uc['wonoise'] = bool(args.wonoise)

    noise = pre_noise.to(c_cross.device)

    gene_img_list = []
    with torch.cuda.amp.autocast(enabled=args.use_fp16, dtype=FP16_DTYPE):
        infer_model.to(args.device)
        infer_model.eval()

        gene_img, _ = infer_model.sample_log(cond=c,
                                    batch_size=args.num_drivings, ddim=True,
                                    ddim_steps=args.ddim_steps, eta=args.eta,
                                    unconditional_guidance_scale=uc_scale,
                                    unconditional_conditioning=uc,
                                    inpaint=None,
                                    x_T=noise,
                                    num_overlap=num_mix,
                                    )

        for i in range(0, nSample, vae_bs):
            gene_img_part = infer_model.decode_first_stage(gene_img[i:i + vae_bs])
            gene_img_list.append(gene_img_part.float().clamp(-1, 1).cpu())

    _, out_c, out_h, out_w = gene_img_list[0].shape

    cond_image = batch_data["conditions"].reshape([-1, out_c, out_h, out_w])[:nSample].cpu()
    l_cond_image = batch_data["local"].reshape([-1, out_c, out_h, out_w])[:nSample].cpu()
    orig_image = batch_data["sources"][:nSample, 0].cpu()

    # Tile four panels side by side per frame: generated | driving | local crops | source.
    output_img = torch.cat(gene_img_list + [cond_image, l_cond_image, orig_image]).float().clamp(-1, 1).add(1).mul(0.5)
    num_cols = 4
    output_img = output_img.reshape([num_cols, 1, nSample, out_c, out_h, out_w]).permute([1, 0, 2, 3, 4, 5])
    output_img = output_img.permute([2, 3, 0, 4, 1, 5]).reshape([-1, out_c, out_h, num_cols * out_w])
    output_img = torch.permute(output_img, [0, 2, 3, 1])

    output_img = img_as_ubyte(output_img.data.cpu().numpy())
    # Only the first panel (the generated frames) is written to the video.
    imageio.mimsave(output_path, output_img[:, :, :out_w], fps=batch_data['fps'], quality=10, pixelformat='yuv420p', codec='libx264')

def main(args):
    """Run inference for every driving video that matches the ``args.driving_video`` glob.

    Builds the model from ``args.model_config`` and loads the checkpoint at
    ``args.resume_dir``; if no checkpoint is given it exits with status 1.
    It then writes one output video per driving video into ``args.output_dir``.
    """
    # ******************************
    # initialize (single-process inference)
    # ******************************
    args.world_size = 1
    args.local_rank = 0
    args.rank = 0
    args.device = torch.device("cuda", args.local_rank)

    # set seed for reproducibility
    set_seed(args.seed)

    # Fail fast, before building the model, if there is no checkpoint to load.
    if args.resume_dir is None:
        print('please provide the correct resume_dir!')
        sys.exit(1)

    # ******************************
    # create model
    # ******************************
    model = create_model(args.model_config).cpu()
    model.sd_locked = args.sd_locked
    model.only_mid_control = args.only_mid_control
    model.to(args.local_rank)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank == 0:
        print('Total base  parameters {:.02f}M'.format(count_param([model])))
    if args.ema_rate is not None and args.ema_rate > 0 and args.rank == 0:
        print(f"Creating EMA model at ema_rate={args.ema_rate}")
        # NOTE: the EMA wrapper is created but not used for inference below.
        model_ema = EMA(model, beta=args.ema_rate, update_after_step=0, update_every=1)
    else:
        model_ema = None

    # ******************************
    # load pre-trained models
    # ******************************
    if args.local_rank == 0:
        load_state_dict(model, args.resume_dir,
                        reinit_hint_block=getattr(args, 'reinit_hint_block', False),
                        strict=False)

    if args.compile and TORCH_VERSION == "2":
        model = torch.compile(model)

    torch.cuda.set_device(args.local_rank)
    print_peak_memory("Max memory allocated after creating DDP", args.local_rank)
    infer_model = model.module if hasattr(model, "module") else model

    # Sorted so that the processing order is deterministic.
    driving_videos = sorted(glob.glob(args.driving_video))
    if not driving_videos:
        print('no driving video matches {}'.format(args.driving_video))
    more_source_image_pattern = getattr(args, 'more_source_image_pattern', "")

    with torch.no_grad():
        for driving_video in driving_videos:
            print ('working on {}'.format(os.path.basename(driving_video)))
            infer_batch_data = x_portrait_data_prep(
                args.source_image, driving_video, args.device, args.best_frame,
                start_idx = args.start_idx, num_frames = args.out_frames, skip=args.skip,
                output_local=True, more_source_image_pattern=more_source_image_pattern,
            )
            infer_batch_data['video_name'] = os.path.basename(driving_video)
            infer_batch_data['source_name'] = args.source_image
            nSample = infer_batch_data['sources'].shape[0]
            visualize_mm(args, "inference", infer_batch_data, infer_model, nSample=nSample, local_image_dir=args.output_dir, num_mix=args.num_mix)


def str2bool(arg):
    """Parse a command-line boolean: an integer (0 = False, non-zero = True) or true/false, yes/no, on/off."""
    text = str(arg).strip().lower()
    if text in ("true", "yes", "y", "on"):
        return True
    if text in ("false", "no", "n", "off"):
        return False
    try:
        return bool(int(text))
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a boolean (0/1/true/false), got {arg!r}") from None


def build_arg_parser():
    """Build the command-line parser for inference."""
    parser = argparse.ArgumentParser(description='X-Portrait inference')
    ## Model
    parser.add_argument('--model_config', type=str, default="model_lib/ControlNet/models/cldm_v15_video_appearance.yaml",
                        help="The path of model config file")
    parser.add_argument('--reinit_hint_block', action='store_true', default=False,
                        help="Re-initialize hint blocks for channel mis-match (hint-block weights are not loaded from the checkpoint)")
    parser.add_argument('--sd_locked', type=str2bool, default=True,
                        help='Freeze parameters in original stable-diffusion decoder')
    parser.add_argument('--only_mid_control', type=str2bool, default=False,
                        help='Only control middle blocks')
    parser.add_argument('--control_type', type=str, default="appearance_pose_local_mm",
                        help='The type of conditioning')
    parser.add_argument("--control_mode", type=str, default="controlnet_important",
                        help="Set controlnet is more important or balance.")
    parser.add_argument('--wonoise', action='store_false', default=True,
                        help='Use with referenceonly, remove adding noise on reference image (on by default; passing this flag turns it off)')

    ## Runtime
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed for initialization')
    parser.add_argument('--use_fp16', action='store_false', default=True,
                        help='16-bit (mixed) precision is on by default; passing this flag disables it')
    parser.add_argument('--compile', type=str2bool, default=False,
                        help='compile model (for torch 2)')
    parser.add_argument('--eta', type=float, default=0.0,
                        help='eta during DDIM Sampling')
    parser.add_argument('--ema_rate', type=float, default=0,
                        help='rate for ema')
    ## inference
    parser.add_argument("--initial_facevid2vid_results", type=str, default=None,
                        help="facevid2vid results for noise initialization")
    parser.add_argument('--ddim_steps', type=int, default=1,
                        help='denoising steps')
    parser.add_argument('--uc_scale', type=int, default=5,
                        help='classifier-free guidance scale')
    parser.add_argument("--num_drivings", type=int, default=16,
                        help="Number of driving images in a single sequence of video.")
    parser.add_argument("--output_dir", type=str, default=None, required=True,
                        help="The output directory where the generated videos will be written.")
    parser.add_argument("--resume_dir", type=str, default=None,
                        help="Path of the model checkpoint file to load.")
    parser.add_argument("--source_image", type=str, default="",
                        help="The source image for neural motion.")
    parser.add_argument("--more_source_image_pattern", type=str, default="",
                        help="Glob pattern of additional source images.")
    parser.add_argument("--driving_video", type=str, default="",
                        help="Glob pattern of driving videos (.mp4) or driving images.")
    parser.add_argument('--best_frame', type=int, default=0,
                        help='best matching frame index (-1: pick automatically, -2: no alignment)')
    parser.add_argument('--start_idx', type=int, default=0,
                        help='starting frame index')
    parser.add_argument('--skip', type=int, default=1,
                        help='skip frame')
    parser.add_argument('--num_mix', type=int, default=4,
                        help='num overlapping frames')
    parser.add_argument('--out_frames', type=int, default=0,
                        help='num frames (0 = all)')
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())