from modules.loader.module_loader import GenericModuleLoader
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
import torch
from modules.params.diffusion.inference_params import InferenceParams
from utils import result_processor
from tqdm import tqdm
from PIL import Image, ImageFilter
from utils.inference_utils import resize_and_crop, get_padding_for_aspect_ratio
import numpy as np
from safetensors.torch import load_file as load_safetensors
import math
from einops import repeat, rearrange
from torchvision.transforms import ToTensor
from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
import PIL
from modules.params.vfi import VFIParams
from modules.params.i2v_enhance import I2VEnhanceParams
from typing import List, Union
from models.diffusion.wrappers import StreamingWrapper
from diffusion_trainer.abstract_trainer import AbstractTrainer
from utils.loader import download_ckpt
import torchvision.transforms.functional as TF
from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler
from transformers import BlipProcessor, BlipForConditionalGeneration

class StreamingSVD(AbstractTrainer):
    def __init__(self,
                 module_loader: GenericModuleLoader,
                 diff_trainer_params: DiffusionTrainerParams,
                 inference_params: InferenceParams,
                 vfi: VFIParams,
                 i2v_enhance: I2VEnhanceParams,
                 ):
        super().__init__(inference_params=inference_params,
                         diff_trainer_params=diff_trainer_params,
                         module_loader=module_loader,
                         )

        # the network config is wrapped by OpenAIWrapper, so we don't need a direct reference anymore
        # this corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules
        del self.network_config  
        self.diff_trainer_params: DiffusionTrainerParams
        self.vfi = vfi
        self.i2v_enhance = i2v_enhance
            
    def on_inference_epoch_start(self):
        super().on_inference_epoch_start()

        # for StreamingSVD we use a model wrapper that combines the base SVD model and the control model.  
        self.inference_model = StreamingWrapper(
            diffusion_model=self.model.diffusion_model,
            controlnet=self.controlnet,
            num_frame_conditioning=self.inference_params.num_conditional_frames
        )
    
    def post_init(self):
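        """
        Post-loading setup (descriptive summary of the code below): disables the SVD pipeline
        progress bar and enables model CPU offloading when running on a GPU, re-uses the
        conditioner's OpenCLIP model as image_encoder_apm, moves rarely used submodules to the
        CPU (presumably to save GPU memory), and creates the Dreamshaper-8 inpainting pipeline
        and BLIP captioning model used to outpaint input images to the target aspect ratio.
        """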
        self.svd_pipeline.set_progress_bar_config(disable=True) 
        if self.device.type != "cpu":
            self.svd_pipeline.enable_model_cpu_offload(gpu_id=self.device.index)

        # re-use the OpenCLIP model already loaded by the image conditioner as image_encoder_apm
        embedders = self.conditioner.embedders
        for embedder in embedders:
            if hasattr(embedder, "input_key") and embedder.input_key == "cond_frames_without_noise":
                self.image_encoder_apm = embedder.open_clip
        self.first_stage_model.to("cpu")
        self.conditioner.embedders[3].encoder.to("cpu")
        self.conditioner.embedders[0].open_clip.to("cpu")

        pipe = AutoPipelineForInpainting.from_pretrained(
            'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16", safety_checker=None, requires_safety_checker=False)
        
        pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(self.device)
        pipe.enable_model_cpu_offload(gpu_id=self.device.index)
        self.inpaint_pipe = pipe

        processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large")


        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device)
        def blip(image):
            # caption the image with BLIP; inputs are moved to the same device/dtype as the model
            inputs = processor(image, return_tensors='pt').to(self.device, torch.float16)
            return processor.decode(model.generate(**inputs)[0], skip_special_tokens=True)
        self.blip = blip
        
    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_unique_embedder_keys_from_conditioner(self, conditioner):
        return list(set([x.input_key for x in conditioner.embedders]))


    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_batch_sgm(self, keys, value_dict, N, T, device):
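        """
        Builds the conditioning batch expected by the SVD conditioner.

        Illustrative shapes, assuming the call made in _generate_conditional_output with
        N = [1, num_frames] and T = num_frames: "fps_id", "motion_bucket_id" and "cond_aug"
        become 1-D tensors of length math.prod(N) = num_frames; "cond_frames" and
        "cond_frames_without_noise" are repeated N[0] = 1 times along the batch dimension;
        batch["num_video_frames"] is set to T. batch_uc contains a clone of every tensor
        entry of batch.
        """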
        batch = {}
        batch_uc = {}

        for key in keys:
            if key == "fps_id":
                batch[key] = (
                    torch.tensor([value_dict["fps_id"]])
                    .to(device)
                    .repeat(int(math.prod(N)))
                )
            elif key == "motion_bucket_id":
                batch[key] = (
                    torch.tensor([value_dict["motion_bucket_id"]])
                    .to(device)
                    .repeat(int(math.prod(N)))
                )
            elif key == "cond_aug":
                batch[key] = repeat(
                    torch.tensor([value_dict["cond_aug"]]).to(device),
                    "1 -> b",
                    b=math.prod(N),
                )
            elif key == "cond_frames":
                batch[key] = repeat(value_dict["cond_frames"],
                                    "1 ... -> b ...", b=N[0])
            elif key == "cond_frames_without_noise":
                batch[key] = repeat(
                    value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
                )
            else:
                batch[key] = value_dict[key]

        if T is not None:
            batch["num_video_frames"] = T

        for key in batch.keys():
            if key not in batch_uc and isinstance(batch[key], torch.Tensor):
                batch_uc[key] = torch.clone(batch[key])
        return batch, batch_uc
    
    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py
    @torch.no_grad()
    def decode_first_stage(self, z):
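        """
        Decodes latents z to pixel space in chunks of at most 8 frames to keep peak memory
        bounded; for a VideoDecoder the chunk length is additionally passed as "timesteps".
        The VAE is moved to self.device only for the duration of this call and back to the
        CPU afterwards.
        """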
        self.first_stage_model.to(self.device)

        z = 1.0 / self.diff_trainer_params.scale_factor * z
        n_samples = min(z.shape[0], 8)
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(
                        z[n * n_samples: (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples: (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        self.first_stage_model.to("cpu")
        return out
    

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params):
        C = 4
        F = 8  # spatial compression factor (TODO: read from model)
   
        H = svd_input_frame.shape[-2]
        W = svd_input_frame.shape[-1]
        num_frames = self.sampler.guider.num_frames

        shape = (num_frames, C, H // F, W // F)
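        # e.g. for the 1024x576 frames produced in image_to_video this is
        # (num_frames, 4, 576 // 8, 1024 // 8) = (num_frames, 4, 72, 128)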
        batch_size = 1

        image = svd_input_frame[None,:]
        cond_aug = 0.02

        value_dict = {}
        value_dict["motion_bucket_id"] = 127
        value_dict["fps_id"] = 6
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = image
        value_dict["cond_frames"] = image + cond_aug * torch.rand_like(image)

        batch, batch_uc = self.get_batch_sgm(
            self.get_unique_embedder_keys_from_conditioner(
                self.conditioner),
            value_dict,
            [1, num_frames],
            T=num_frames,
            device=self.device,
        )

        self.conditioner.embedders[3].encoder.to(self.device)
        self.conditioner.embedders[0].open_clip.to(self.device)
        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=[
                "cond_frames",
                "cond_frames_without_noise",
            ],
        )
        self.conditioner.embedders[3].encoder.to("cpu")
        self.conditioner.embedders[0].open_clip.to("cpu")


        for k in ["crossattn", "concat"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
            c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

        randn = torch.randn(shape, device=self.device)
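        # Note (assumption, not stated in this file): the factor 2 in the batch-dependent
        # inputs below accounts for the conditional and unconditional branches that the
        # classifier-free-guidance sampler evaluates in a single batch.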

        additional_model_inputs = {}
        additional_model_inputs["image_only_indicator"] = torch.zeros(2 * batch_size, num_frames).to(self.device)
        additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

        # StreamingSVD inputs
        additional_model_inputs["batch_size"] = 2*batch_size
        additional_model_inputs["num_conditional_frames"] = self.inference_params.num_conditional_frames
        additional_model_inputs["ctrl_frames"] = params["ctrl_frames"]

        self.inference_model.diffusion_model = self.inference_model.diffusion_model.to(
            self.device)
        self.inference_model.controlnet = self.inference_model.controlnet.to(
            self.device)

        c["vector"] = c["vector"].to(randn.dtype)
        uc["vector"] = uc["vector"].to(randn.dtype)
        def denoiser(input, sigma, c):
            return self.denoiser(self.inference_model, input, sigma, c, **additional_model_inputs)

        samples_z = self.sampler(denoiser, randn, cond=c, uc=uc)

        self.inference_model.diffusion_model = self.inference_model.diffusion_model.to("cpu")
        self.inference_model.controlnet = self.inference_model.controlnet.to("cpu")
        samples_x = self.decode_first_stage(samples_z)
        
        samples = torch.clamp(samples_x, min=-1.0, max=1.0)
        return samples
        

    def extract_anchor_frames(self, video, input_range, inference_params: InferenceParams):
        """
        Extracts anchor frames from the input video based on the provided inference parameters.

        Parameters:
        - video: torch.Tensor
            The input video tensor.
        - input_range: list
            The pixel value range of the input video.
        - inference_params: InferenceParams
            An object containing inference parameters.
            - anchor_frames: str
                Specifies how the anchor frames are selected. It can be either a single index (e.g. "0")
                identifying the frame used as anchor, or a range in the format "a:b" selecting frames from
                index a up to, but not including, index b (Python slicing semantics).

        Returns:
        - torch.Tensor
            The extracted anchor frames from the input video.
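
        Example (illustrative): anchor_frames="0" selects frame 0 only, while
        anchor_frames="0:4" selects frames 0, 1, 2 and 3 (the end index is exclusive,
        following Python slicing).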
        """
        video = result_processor.convert_range(video=video.clone(),input_range=input_range,output_range=[-1,1])

        if video.shape[1] == 3 and video.shape[0]>3:
            video = rearrange(video,"F C W H -> 1 F C W H")
        elif video.shape[0]>3 and video.shape[-1] == 3:
            video = rearrange(video,"F W H C -> 1 F C W H")
        else:
            raise NotImplementedError(f"Unexpected video input format: {video.shape}")

        if ":" in inference_params.anchor_frames:        
            anchor_frames = inference_params.anchor_frames.split(":")
            anchor_frames = [int(anchor_frame) for anchor_frame in anchor_frames]
            assert len(anchor_frames) == 2, "Anchor frames encoding must be of the form 'a:b'."
            anchor = video[:,anchor_frames[0]:anchor_frames[1]]
        else:
            anchor_frame = int(inference_params.anchor_frames)
            anchor = video[:, anchor_frame].unsqueeze(0)

        return anchor
    
    def extract_ctrl_frames(self, video: torch.Tensor, input_range: List[int], inference_params: InferenceParams):
        """
        Extracts control frames from the input video.

        Parameters:
        - video: torch.Tensor
            The input video tensor.
        - input_range: list
            The pixel value range of the input video.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The frames extracted from the input video to be used for the control image encoding.
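
        Example (illustrative): with num_conditional_frames = 7, the last 7 frames of the
        chunk are returned (with a leading batch dimension of 1).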
        """
        video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1])
        if video.shape[1] == 3 and video.shape[0] > 3:
            video = rearrange(video, "F C W H -> 1 F C W H")
        elif video.shape[0] > 3 and video.shape[-1] == 3:
            video = rearrange(video, "F W H C -> 1 F C W H")
        else:
            raise NotImplementedError(
                f"Unexpected video input format: {video.shape}")
        
        # return the last num_conditional_frames frames
        video = video[:, -inference_params.num_conditional_frames:]
        return video


    def _autoregressive_generation(self, initial_generation: Union[torch.Tensor, List[torch.Tensor]], inference_params: InferenceParams):
        """
        Perform autoregressive generation of video chunks based on the initial generation and inference parameters.

        Parameters:
        - initial_generation: torch.Tensor or list of torch.Tensor
            The initial video chunk, or a list of initial video chunks.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The generated video resulting from autoregressive generation.
        """

        # input is [-1,1] float
        result_chunks = initial_generation
        if not isinstance(result_chunks,list):
            result_chunks = [result_chunks]
        
        # make sure the chunks are channels-first [F, C, ...] (convert from channels-last if needed)
        if (result_chunks[0].shape[1] > 3) and (result_chunks[0].shape[-1] == 3):
            result_chunks = [rearrange(result_chunks[0], "F W H C -> F C W H")]

        # generating chunk by conditioning on the previous chunks
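        # each iteration keeps (chunk_length - num_conditional_frames) newly generated frames, so the
        # final video has F_initial + n_autoregressive_generations * (chunk_length - num_conditional_frames)
        # frames, where F_initial is the length of the initial generation and chunk_length is the number
        # of frames the sampler produces per chunk.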
        for _ in tqdm(range(inference_params.n_autoregressive_generations), desc="StreamingSVD"):
            
            # extract anchor frames based on the entire video generated so far
            # note that we do not use anchor frames in StreamingSVD (apart from the anchor frame already used by SVD).
            anchor_frames = self.extract_anchor_frames(
                video = torch.cat(result_chunks), 
                inference_params=inference_params, 
                input_range=[-1, 1],
                )
            
            # extract control frames based on the last generated chunk
            ctrl_frames = self.extract_ctrl_frames(
                video = result_chunks[-1],
                input_range=[-1, 1],
                inference_params=inference_params,
                )

            # select the anchor frame for svd
            svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)]
                 
            # generate the next chunk
            # result is [F, C, H, W], range is [-1,1] float.
            result = self._generate_conditional_output(
                                                      svd_input_frame = svd_input_frame,
                                                      inference_params=inference_params,
                                                      anchor_frames=anchor_frames,
                                                      ctrl_frames=ctrl_frames,
                                                      )

            # from each generation, we keep all frames except for the first <num_conditional_frames> frames
            result = result[inference_params.num_conditional_frames:]
            result_chunks.append(result) 
            torch.cuda.empty_cache()

        # concat all chunks to one long video
        result_chunks = [result_processor.convert_range(chunk,output_range=[0,255],input_range=[-1,1]) for chunk in result_chunks]
        result = result_processor.concat_chunks(result_chunks)
        torch.cuda.empty_cache()
        return result

    def ensure_image_ratio(self, source_image: Image.Image, target_aspect_ratio=16/9):
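        """
        Outpaints the source image to the target aspect ratio (descriptive summary of the
        code below): if the aspect ratio already matches, the image is returned unchanged
        (with a None mask). Otherwise the image is reflection-padded to the target ratio,
        the padded border is filled by the inpainting pipeline prompted with a BLIP caption
        of the source image, and the result is composited onto the padded input using a
        Gaussian-blurred mask. Returns the outpainted image and the outpainting mask.
        """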

        if source_image.width / source_image.height == target_aspect_ratio:
            return source_image, None
        
        image = source_image.copy().convert("RGBA")
        mask = image.split()[-1]
        image = image.convert("RGB")
        padding = get_padding_for_aspect_ratio(image)


        mask_padded = TF.pad(mask, padding)
        mask_padded_size = mask_padded.size
        mask_padded_resized = TF.resize(mask_padded, (512, 512),
                                        interpolation=TF.InterpolationMode.NEAREST)
        mask_padded_resized = TF.invert(mask_padded_resized)

        # image
        padded_input_image = TF.pad(image, padding, padding_mode="reflect")
        resized_image = TF.resize(padded_input_image, (512, 512))

        image_tensor = (self.inpaint_pipe.image_processor.preprocess(
            resized_image).cuda().half())
        latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None)
        self.inpaint_pipe.scheduler.set_timesteps(999)
        noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise(
            latent_tensor,
            torch.randn_like(latent_tensor),
            self.inpaint_pipe.scheduler.timesteps[:1],
        )

        prompt = self.blip(source_image)
        if prompt.startswith("there is "):
            prompt = prompt[len("there is "):]

        output_image_normalized_size = self.inpaint_pipe(
            prompt=prompt,
            image=resized_image,
            mask_image=mask_padded_resized,
            latents=noisy_latent_tensor,
        ).images[0]

        output_image_extended_size = TF.resize(
            output_image_normalized_size, mask_padded_size[::-1])

        blurred_outpainting_mask = TF.invert(mask_padded).filter(
            ImageFilter.GaussianBlur(radius=5))

        final_image = Image.composite(
            output_image_extended_size, padded_input_image, blurred_outpainting_mask)
        return final_image, TF.invert(mask_padded)


    def image_to_video(self, batch, inference_params: InferenceParams, batch_idx):

        """
        Performs image to video based on the input batch and inference parameters.
        It runs SVD-XT once to generate the first chunk, then auto-regressively applies StreamingSVD.

        Parameters:
        - batch: dict
            The input batch containing the start image for generating the video.
        - inference_params: InferenceParams
            An object containing inference parameters.
        - batch_idx: int
            The index of the batch.

        Returns:
        - torch.Tensor
            The generated video based on the input image.
        """
        batch_key = "image"
        assert batch_key == "image", f"Generating video from {batch_key} not implemented."
        input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy())
        # TODO remove conversion forth and back

        outpainted_image, _ = self.ensure_image_ratio(input_image)

        scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image)
        assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}."
        
        # Generating first chunk
        with torch.autocast(device_type="cuda", enabled=False):
            video_chunks = self.svd_pipeline(
                scaled_outpainted_image, decode_chunk_size=8).frames[0]

        video_chunks = torch.stack([ToTensor()(frame) for frame in video_chunks])
        video_chunks = video_chunks * 2.0 - 1 # [-1,1], float

        video_chunks = video_chunks.to(self.device)
    
        video = self._autoregressive_generation(
                                                initial_generation=video_chunks,
                                                inference_params=inference_params)

        return video, scaled_outpainted_image, expanded_size


    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        """
        Generate output video based on the input batch and inference parameters.

        Parameters:
        - batch: dict
            The input batch containing data for generating the output video.
        - batch_idx: int
            The index of the batch.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The generated video. Note that the result is also accessible via self.trainer.generated_video.
        """

        sample_id = batch["sample_id"].item()
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            batch, inference_params=inference_params, batch_idx=sample_id)
        
        self.trainer.generated_video = video.numpy()
        self.trainer.expanded_size = expanded_size
        self.trainer.scaled_outpainted_image = scaled_outpainted_image
        return video