import os, sys, time, traceback
# Prepend "<directory of this file>/dust3r" to sys.path (and log it) so modules
# located under that directory can be imported; presumably this is what makes the
# `dust3r.*` imports below resolve — confirm against the repo layout.
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))

import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch

from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo

from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func


class MeshBaker:
    """Bake generated views into a mesh texture, gated by DUSt3R alignment quality.

    For the 0° and 180° views, an image is warped onto a rendering of the mesh
    with ``infer_warp_mesh_img``. The warped image is back-projected into the
    texture with ``back_projection`` only when its mask IoU exceeds
    ``iou_thresh``, or when the angle is listed in ``force_baking_ele_list``.
    After a view is baked, later views use
    ``<save_folder>/view_<idx>/bake/{mesh.obj,texture.png}`` as their input
    mesh and texture.
    """

    def __init__(
        self,
        align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device = "cuda:0",
        align_times = 1,
        iou_thresh = 0.8,
        force_baking_ele_list = None,
        save_memory = False
    ):
        """
        Args:
            align_model: checkpoint passed to ``AsymmetricCroCo3DStereo.from_pretrained``.
            device: device for the alignment model and for ``back_projection``.
            align_times: number of alignment passes used for the front (0°) view.
            iou_thresh: a view is baked only if its best mask IoU exceeds this.
            force_baking_ele_list: angles (from the [0, 180] loop in ``call``)
                that are baked regardless of IoU; entries are cast to ``int``.
            save_memory: if True, keep the alignment model on CPU and move it
                to ``device`` only for the duration of each ``__call__``.
        """
        self.device = device
        self.save_memory = save_memory
        self.align_times = align_times
        self.iou_thresh = iou_thresh
        # Cast to int so membership tests against the int angles in ``call`` work
        # even when the caller passed e.g. strings.
        self.force_baking_ele_list = [
            int(e) for e in ([] if force_baking_ele_list is None else force_baking_ele_list)
        ]

        model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        if not save_memory:
            model = model.to(device)
        model.eval()
        set_parameter_grad_false(model)
        self.align_model = model
        print('baking align model', get_parameter_number(self.align_model))

    def align_and_check(self, src, dst, align_times=3):
        """Iteratively warp ``src`` onto ``dst`` and keep the best-aligned result.

        Each pass feeds the previous pass's output back into
        ``infer_warp_mesh_img``. Refinement stops early once a pass falls below
        ``self.iou_thresh``.

        Returns:
            ``(best_image, best_info, bakeable)``. ``best_info`` contains
            ``'mask_iou'``, and ``bakeable`` is
            ``best_info['mask_iou'] > self.iou_thresh``. If no pass beats
            ``self.iou_thresh - 0.1``, ``best_image`` is ``src`` unchanged.
            On any exception the traceback is printed and
            ``(None, None, None)`` is returned.
        """
        try:
            current = best_image = src
            # Sentinel: a pass is only recorded if it beats (iou_thresh - 0.1).
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh - 0.1}
            for i in range(align_times):
                warped, info = infer_warp_mesh_img(current, dst, self.align_model, vis=False)
                current = Image.fromarray(warped)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_image, best_info = current, info
                # Stop refining as soon as a pass drops below the threshold.
                if info['mask_iou'] < self.iou_thresh:
                    break
            print(f"Best Baking Info:{best_info['mask_iou']}")
            return best_image, best_info, best_info['mask_iou'] > self.iou_thresh
        except Exception as e:
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run ``call``; with ``save_memory``, the model is on ``device`` only meanwhile.

        The model is moved back to CPU (and the CUDA cache emptied) even if
        ``call`` raises, so a failure does not leave it resident on the GPU.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    @staticmethod
    def _load_views(save_folder):
        """Load ``(cond_image, views)`` from ``views.jpg`` + ``img_nobg.png``, or from ``views.gif``.

        Raises:
            FileNotFoundError: if neither ``views.jpg`` nor ``views.gif`` exists.
        """
        views_jpg = os.path.join(save_folder, "views.jpg")
        views_gif = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(views_jpg):
            with Image.open(views_jpg) as grid_img:
                grid = np.asarray(grid_img, dtype=np.uint8)
            # Split a 3-row x 2-column grid into 6 tiles (row-major), then reorder them.
            tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0, 2, 4, 5, 3, 1]]
            with Image.open(cond_path) as cond_img:
                cond_pil = cond_img.resize((512, 512))
            return cond_pil, views

        if os.path.exists(views_gif):
            # First frame is the condition image; the remaining frames are the views.
            with Image.open(views_gif) as gif:
                frames = [frame.convert('RGB') for frame in ImageSequence.Iterator(gif)]
            return frames[0], frames[1:]

        raise FileNotFoundError("views file not found")

    def _align_front(self, cond_pil, front_view, rendered_front, save_folder):
        """Align the condition image and the generated front view to the 0° render; return ``(image, need_baking)``.

        The generated front view is preferred unless the condition image
        aligned strictly better, or the front view failed to align.
        Returns ``(None, False)`` if both alignments failed.
        """
        aligned_cond, cond_info, _ = self.align_and_check(
            cond_pil, rendered_front, align_times=self.align_times
        )
        if aligned_cond is not None:
            aligned_cond.save(os.path.join(save_folder, "aligned_cond.jpg"))

        aligned_img, info, _ = self.align_and_check(
            front_view, rendered_front, align_times=self.align_times
        )
        if aligned_img is not None:
            aligned_img.save(os.path.join(save_folder, "aligned_0.jpg"))

        if cond_info is not None and (info is None or info['mask_iou'] < cond_info['mask_iou']):
            print("Using Cond Image to bake front view")
            aligned_img, info = aligned_cond, cond_info
        if info is None:
            return None, False
        return aligned_img, info['mask_iou'] > self.iou_thresh

    def call(self, save_folder):
        """Bake the 0° and 180° views found in ``save_folder`` into the mesh texture.

        Reads ``mesh.obj`` and ``texture.png``, plus either ``views.jpg`` with
        ``img_nobg.png``, or ``views.gif``. Writes ``aligned_*.jpg`` debug
        images and a ``view_<idx>`` output directory for each baked view.
        A view whose alignment fails is skipped instead of crashing the run.

        Returns:
            Path of the most recently baked mesh, or the input ``mesh.obj``
            if no view was baked.

        Raises:
            FileNotFoundError: if neither ``views.jpg`` nor ``views.gif`` exists.
        """
        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        cond_pil, views = self._load_views(save_folder)

        # Rendered once from the input mesh. The loop below assumes entry i
        # corresponds to the i-th angle in [0, 180].
        rendered_views = render_func(obj_path, elev=0, n_views=2)

        for ele_idx, ele in enumerate([0, 180]):
            if ele == 0:
                aligned_img, need_baking = self._align_front(
                    cond_pil, views[0], rendered_views[0], save_folder
                )
            else:
                # NOTE(review): this uses align_and_check's default align_times (3),
                # not self.align_times as the front view does — confirm intended.
                # views[ele // 60] selects views[3] for 180°.
                aligned_img, _, need_baking = self.align_and_check(
                    views[ele // 60], rendered_views[ele_idx]
                )
                if aligned_img is not None:
                    aligned_img.save(os.path.join(save_folder, f"aligned_{ele}.jpg"))

            if aligned_img is None:
                # align_and_check already printed the traceback; there is nothing to bake.
                print(f"Skip view_{ele_idx} elevation_{ele} baking: alignment failed")
                continue
            if not (need_baking or ele in self.force_baking_ele_list):
                print(f"Skip view_{ele_idx} elevation_{ele} baking")
                continue

            st = time.time()
            back_projection(
                obj_file=obj_path,
                init_texture_file=raw_texture_path,
                front_view_file=aligned_img,
                dst_dir=os.path.join(save_folder, f"view_{ele_idx}"),
                render_resolution=aligned_img.size[0],  # PIL .size is (width, height)
                uv_resolution=1024,
                views=[[0, ele]],
                device=self.device,
            )
            print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
            # Chain bakes: the next view starts from this view's baked mesh and texture.
            obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
            raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")

        print("Baking Finished")
        return obj_path
    

if __name__ == "__main__":
    # Smoke run: bake ./outputs/test with default settings and print the final mesh path.
    mesh_baker = MeshBaker()
    print(mesh_baker("./outputs/test"))