from math import ceil
from typing import Any
from typing import Optional

import numpy as np
import torch
from diffusers import LCMScheduler

from backend.device import is_openvino_device
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
from backend.models.lcmdiffusion_setting import LCMLora
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.openvino.pipelines import (
    get_ov_text_to_image_pipeline,
    ov_load_taesd,
    get_ov_image_to_image_pipeline,
)
from backend.pipelines.lcm import (
    get_lcm_model_pipeline,
    load_taesd,
    get_image_to_image_pipeline,
)
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
from constants import DEVICE
from image_ops import resize_pil_image


class LCMTextToImage:
    """Build, cache and run LCM diffusion pipelines.

    ``init`` creates the pipeline(s) required by a ``LCMDiffusionSetting`` and
    remembers the settings they were built from, so that a later ``init`` call
    with identical settings reuses them. ``generate`` runs the prepared
    pipeline for either the text-to-image or the image-to-image task.

    Two backends are handled: OpenVINO pipelines (used when the setting asks
    for OpenVINO *and* ``is_openvino_device()`` is true) and torch pipelines
    (everything else).
    """

    def __init__(
        self,
        device: str = "cpu",
    ) -> None:
        """Create an empty, uninitialised generator.

        Args:
            device: Accepted for interface compatibility; it is not stored
                here. ``self.device`` is assigned by ``init`` and the torch
                pipeline is moved to the module-level ``DEVICE``.
        """
        self.pipeline = None
        self.use_openvino = False
        self.device = ""
        # Snapshot of the settings the current pipeline was built from; used
        # by ``init`` to decide whether a rebuild is needed.
        self.previous_model_id = None
        self.previous_use_tae_sd = False
        self.previous_use_lcm_lora = False
        self.previous_lcm_lora_base_id = None
        self.previous_lcm_lora_id = None
        self.previous_ov_model_id = ""
        self.previous_safety_checker = False
        self.previous_use_openvino = False
        self.previous_task_type = None
        self.img_to_img_pipeline = None
        # True right after an OpenVINO pipeline was (re)built; ``generate``
        # skips reshape/compile on that first run and then clears the flag.
        self.is_openvino_init = False
        # float32 on OpenVINO devices and on "mps", float16 otherwise.
        self.torch_data_type = (
            torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
        )
        print(f"Torch datatype : {self.torch_data_type}")

    def _pipeline_to_device(self):
        """Move ``self.pipeline`` to ``DEVICE`` using ``self.torch_data_type``."""
        print(f"Pipeline device : {DEVICE}")
        print(f"Pipeline dtype : {self.torch_data_type}")
        self.pipeline.to(
            torch_device=DEVICE,
            torch_dtype=self.torch_data_type,
        )

    def _add_freeu(self):
        """Enable FreeU on ``self.pipeline`` when it uses an ``LCMScheduler``.

        The FreeU factors depend on the pipeline class name; pipelines other
        than ``StableDiffusionPipeline`` / ``StableDiffusionXLPipeline`` are
        left untouched.
        """
        if not isinstance(self.pipeline.scheduler, LCMScheduler):
            return

        pipeline_class = self.pipeline.__class__.__name__
        if pipeline_class == "StableDiffusionPipeline":
            print("Add FreeU - SD")
            self.pipeline.enable_freeu(
                s1=0.9,
                s2=0.2,
                b1=1.2,
                b2=1.4,
            )
        elif pipeline_class == "StableDiffusionXLPipeline":
            print("Add FreeU - SDXL")
            self.pipeline.enable_freeu(
                s1=0.6,
                s2=0.4,
                b1=1.1,
                b2=1.2,
            )

    def _update_lcm_scheduler_params(self):
        """Recreate an existing ``LCMScheduler`` with custom beta bounds.

        Does nothing when the pipeline's scheduler is not an ``LCMScheduler``.
        """
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            self.pipeline.scheduler = LCMScheduler.from_config(
                self.pipeline.scheduler.config,
                beta_start=0.001,
                beta_end=0.01,
            )

    def _release_pipelines(self):
        """Drop references to any previously built pipelines."""
        self.pipeline = None
        self.img_to_img_pipeline = None

    def init(
        self,
        device: str = "cpu",
        lcm_diffusion_setting: Optional[LCMDiffusionSetting] = None,
    ) -> None:
        """Prepare the pipeline(s) described by ``lcm_diffusion_setting``.

        The pipelines are rebuilt only when no pipeline exists yet or when one
        of the tracked settings (model ids, LCM-LoRA ids, tiny auto encoder,
        safety checker, OpenVINO flag, diffusion task) differs from the
        previous call.

        For the image-to-image task the setting's ``init_image`` is replaced
        in place by a copy resized to ``image_width`` x ``image_height``.

        Args:
            device: Stored in ``self.device``.
            lcm_diffusion_setting: Settings to build from. A fresh
                ``LCMDiffusionSetting()`` is created when omitted, so no
                instance is shared between calls.
        """
        if lcm_diffusion_setting is None:
            lcm_diffusion_setting = LCMDiffusionSetting()

        self.device = device
        self.use_openvino = lcm_diffusion_setting.use_openvino
        model_id = lcm_diffusion_setting.lcm_model_id
        use_local_model = lcm_diffusion_setting.use_offline_model
        use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
        use_lora = lcm_diffusion_setting.use_lcm_lora
        lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
        ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id
        task = lcm_diffusion_setting.diffusion_task
        is_text_to_image = task == DiffusionTask.text_to_image.value
        is_image_to_image = task == DiffusionTask.image_to_image.value

        if is_image_to_image:
            lcm_diffusion_setting.init_image = resize_pil_image(
                lcm_diffusion_setting.init_image,
                lcm_diffusion_setting.image_width,
                lcm_diffusion_setting.image_height,
            )

        settings_changed = (
            self.pipeline is None
            or self.previous_model_id != model_id
            or self.previous_use_tae_sd != use_tiny_auto_encoder
            or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
            or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
            or self.previous_use_lcm_lora != use_lora
            or self.previous_ov_model_id != ov_model_id
            or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
            or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
            # Without this check a switch between text-to-image and
            # image-to-image would keep the pipeline built for the other task.
            or self.previous_task_type != task
        )
        if not settings_changed:
            return

        is_openvino_pipe = self.use_openvino and is_openvino_device()

        # Release both pipelines regardless of the backend chosen below, so a
        # torch image-to-image pipeline is not kept alive after switching to
        # OpenVINO.
        self._release_pipelines()

        if is_openvino_pipe:
            self.is_openvino_init = True
            if is_text_to_image:
                print(f"***** Init Text to image (OpenVINO) - {ov_model_id} *****")
                self.pipeline = get_ov_text_to_image_pipeline(
                    ov_model_id,
                    use_local_model,
                )
            elif is_image_to_image:
                print(f"***** Image to image (OpenVINO) - {ov_model_id} *****")
                self.pipeline = get_ov_image_to_image_pipeline(
                    ov_model_id,
                    use_local_model,
                )
        else:
            if use_lora:
                print(
                    f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
                )
                self.pipeline = get_lcm_lora_pipeline(
                    lcm_lora.base_model_id,
                    lcm_lora.lcm_lora_id,
                    use_local_model,
                    torch_data_type=self.torch_data_type,
                )
            else:
                print(f"***** Init LCM Model pipeline - {model_id} *****")
                self.pipeline = get_lcm_model_pipeline(
                    model_id,
                    use_local_model,
                )

            if is_image_to_image:
                # The image-to-image pipeline is derived from the base one.
                self.img_to_img_pipeline = get_image_to_image_pipeline(
                    self.pipeline
                )
            self._pipeline_to_device()

        if use_tiny_auto_encoder:
            if is_openvino_pipe:
                print("Using Tiny Auto Encoder (OpenVINO)")
                ov_load_taesd(
                    self.pipeline,
                    use_local_model,
                )
            else:
                print("Using Tiny Auto Encoder")
                # Attach the tiny auto encoder to the pipeline that will
                # actually be executed for this task.
                if is_text_to_image:
                    load_taesd(
                        self.pipeline,
                        use_local_model,
                        self.torch_data_type,
                    )
                elif is_image_to_image:
                    load_taesd(
                        self.img_to_img_pipeline,
                        use_local_model,
                        self.torch_data_type,
                    )

        # NOTE(review): this branch keys on the raw ``use_openvino`` flag (not
        # on ``is_openvino_device()``), as in the original code; it installs a
        # plain LCMScheduler without the custom beta bounds.
        if is_image_to_image and lcm_diffusion_setting.use_openvino:
            self.pipeline.scheduler = LCMScheduler.from_config(
                self.pipeline.scheduler.config,
            )
        else:
            self._update_lcm_scheduler_params()

        if use_lora:
            self._add_freeu()

        # Remember what the pipelines were built from.
        self.previous_model_id = model_id
        self.previous_ov_model_id = ov_model_id
        self.previous_use_tae_sd = use_tiny_auto_encoder
        self.previous_lcm_lora_base_id = lcm_lora.base_model_id
        self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
        self.previous_use_lcm_lora = use_lora
        self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
        self.previous_use_openvino = lcm_diffusion_setting.use_openvino
        self.previous_task_type = task

        if is_text_to_image:
            print(f"Pipeline : {self.pipeline}")
        elif is_image_to_image:
            if is_openvino_pipe:
                print(f"Pipeline : {self.pipeline}")
            else:
                print(f"Pipeline : {self.img_to_img_pipeline}")

    def generate(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
        reshape: bool = False,
    ) -> Any:
        """Run the prepared pipeline and return the generated images.

        ``init`` must have been called with a matching setting beforehand.

        Args:
            lcm_diffusion_setting: Prompt, size, step count and other
                generation parameters.
            reshape: When true and the OpenVINO pipeline is in use, reshape
                and recompile it for the requested size and image count
                (skipped on the first run after the pipeline was built).

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            ValueError: If the diffusion task is neither text-to-image nor
                image-to-image.
        """
        task = lcm_diffusion_setting.diffusion_task
        is_text_to_image = task == DiffusionTask.text_to_image.value
        is_image_to_image = task == DiffusionTask.image_to_image.value
        if not (is_text_to_image or is_image_to_image):
            raise ValueError(f"Unsupported diffusion task: {task}")

        guidance_scale = lcm_diffusion_setting.guidance_scale
        strength = lcm_diffusion_setting.strength
        img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
        check_step_value = int(lcm_diffusion_setting.inference_steps * strength)
        # If steps * strength rounds down to zero, raise the step count so that
        # the product reaches at least one. A non-positive strength is left
        # untouched (avoids dividing by zero); the pipeline receives it as is.
        if is_image_to_image and check_step_value < 1 and strength > 0:
            img_to_img_inference_steps = ceil(1 / strength)
            print(
                f"Strength: {lcm_diffusion_setting.strength},{img_to_img_inference_steps}"
            )

        if lcm_diffusion_setting.use_seed:
            cur_seed = lcm_diffusion_setting.seed
            if self.use_openvino:
                np.random.seed(cur_seed)
            else:
                torch.manual_seed(cur_seed)

        is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
        if is_openvino_pipe:
            print("Using OpenVINO")
            if reshape and not self.is_openvino_init:
                print("Reshape and compile")
                self.pipeline.reshape(
                    batch_size=-1,
                    height=lcm_diffusion_setting.image_height,
                    width=lcm_diffusion_setting.image_width,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                )
                self.pipeline.compile()

            if self.is_openvino_init:
                self.is_openvino_init = False

        if not lcm_diffusion_setting.use_safety_checker:
            self.pipeline.safety_checker = None
            if is_image_to_image and not is_openvino_pipe:
                self.img_to_img_pipeline.safety_checker = None

        if (
            not lcm_diffusion_setting.use_lcm_lora
            and not lcm_diffusion_setting.use_openvino
            and lcm_diffusion_setting.guidance_scale != 1.0
        ):
            print("Not using LCM-LoRA so setting guidance_scale 1.0")
            guidance_scale = 1.0

        if is_text_to_image:
            # Same call for the OpenVINO and the torch pipeline.
            result_images = self.pipeline(
                prompt=lcm_diffusion_setting.prompt,
                negative_prompt=lcm_diffusion_setting.negative_prompt,
                num_inference_steps=lcm_diffusion_setting.inference_steps,
                guidance_scale=guidance_scale,
                width=lcm_diffusion_setting.image_width,
                height=lcm_diffusion_setting.image_height,
                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
            ).images
        elif is_openvino_pipe:
            # Dispatch on the same condition ``init`` used to build the
            # OpenVINO pipeline, so the pipeline called here always exists.
            # NOTE(review): the step count is tripled on this path only; the
            # reason is not visible in this file.
            result_images = self.pipeline(
                image=lcm_diffusion_setting.init_image,
                strength=lcm_diffusion_setting.strength,
                prompt=lcm_diffusion_setting.prompt,
                negative_prompt=lcm_diffusion_setting.negative_prompt,
                num_inference_steps=img_to_img_inference_steps * 3,
                guidance_scale=guidance_scale,
                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
            ).images
        else:
            result_images = self.img_to_img_pipeline(
                image=lcm_diffusion_setting.init_image,
                strength=lcm_diffusion_setting.strength,
                prompt=lcm_diffusion_setting.prompt,
                negative_prompt=lcm_diffusion_setting.negative_prompt,
                num_inference_steps=img_to_img_inference_steps,
                guidance_scale=guidance_scale,
                width=lcm_diffusion_setting.image_width,
                height=lcm_diffusion_setting.image_height,
                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
            ).images

        return result_images