import torch
import torch.nn as nn

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import PretrainedConfig, PreTrainedModel

class SEPath(nn.Module):
    """Squeeze-and-excitation gate that re-weights one feature map using another.

    Channel statistics are taken from ``in_tensor`` (global average over its two
    trailing spatial dims). They are passed through a bias-free two-layer
    bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid). The result is used as
    per-channel multipliers for ``out_tensor``. The two tensors may have
    different channel counts and spatial sizes.

    Args:
        in_channels: channel count of ``in_tensor`` (input width of the MLP).
        out_channels: channel count of ``out_tensor`` (output width of the MLP).
        reduction: bottleneck divisor; the hidden width is
            ``in_channels // reduction``.

    NOTE(review): when ``in_channels < reduction`` the hidden width is 0. The MLP
    then has no effective weights and the gate is the constant
    ``sigmoid(0) == 0.5``. This is left unchanged because altering the layer
    shapes would make existing state dicts fail to load.
    """

    def __init__(self, in_channels, out_channels, reduction=16):
        super().__init__()
        hidden = in_channels // reduction
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, in_tensor, out_tensor):
        """Return ``out_tensor`` scaled per channel by a gate computed from ``in_tensor``.

        Args:
            in_tensor: rank-4 tensor ``(B, in_channels, H, W)``. It does not
                need to be contiguous.
            out_tensor: tensor broadcastable against ``(B, out_channels, 1, 1)``,
                presumably ``(B, out_channels, H', W')``.

        Returns:
            ``out_tensor * gate``, where ``gate`` has shape ``(B, out_channels, 1, 1)``.

        Raises:
            ValueError: if ``in_tensor`` is not 4-D.
        """
        if in_tensor.dim() != 4:
            raise ValueError(
                f"SEPath expects a 4-D in_tensor, got shape {tuple(in_tensor.shape)}"
            )
        # Squeeze: global average pool over the spatial dims. Unlike .view(),
        # this also works for non-contiguous inputs.
        squeezed = in_tensor.mean(dim=(2, 3))
        # Excitation: per-channel gate in (0, 1), reshaped to broadcast over H', W'.
        gate = self.fc(squeezed)[:, :, None, None]
        return out_tensor * gate

class SeResVaeConfig(PretrainedConfig):
    """Configuration for the ``seresvae`` model type.

    Args:
        base_model: checkpoint id or path. The model loads the ``vae`` and
            ``unet`` subfolders of this checkpoint.
        height: target image height used when preprocessing PIL inputs.
        width: target image width used when preprocessing PIL inputs.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "seresvae"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs,
    ):
        # Model-specific fields are stored first, then the base class consumes the rest.
        self.base_model, self.height, self.width = base_model, height, width
        super().__init__(**kwargs)

class SeResVaeModel(PreTrainedModel):
    """Maps ``images_gray`` to ``images_rgb`` using a VAE, one UNet step and SE-gated skips.

    Pipeline:

    1. The VAE encoder is run manually so that intermediate activations are kept.
    2. The posterior mode is passed once through the UNet at a fixed timestep.
    3. The predicted noise is subtracted from the latent.
    4. The VAE decoder is run manually, with each stage's input gated by an
       ``SEPath`` fed from the matching encoder activation.
    """

    config_class = SeResVaeConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
        # One gate per skip connection, in the order decode_with_residual uses them:
        #   [0] quant_conv output -> latent fed to post_quant_conv
        #   [1] encoder mid-block -> decoder mid-block input
        #   [2..4] encoder down blocks 2, 1, 0 -> decoder up blocks 1, 2, 3
        # Channel counts are hard-coded; presumably they match the base VAE — confirm
        # if base_model changes.
        self.se_paths = nn.ModuleList([SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)])
        # Learned stand-in for text-encoder hidden states; repeated over the batch in forward().
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type='pil', output_type='pil'):
        """Run encode -> single UNet step -> decode on a batch of images.

        Args:
            images_gray: input images. With ``input_type='pil'``, anything
                ``VaeImageProcessor.preprocess`` accepts; it is resized to
                ``(self.height, self.width)``. With ``input_type='pt'``, an
                already-preprocessed 4-D tensor.
            input_type: ``'pil'`` or ``'pt'``.
            output_type: ``'pil'``, ``'np'`` or ``'pt'`` (all produced by
                ``VaeImageProcessor.postprocess``), or ``'none'`` for the raw
                decoder output tensor.

        Returns:
            The decoded images in the requested ``output_type``.

        Raises:
            ValueError: for an unsupported ``input_type`` or ``output_type``.
        """
        if input_type not in ('pil', 'pt'):
            raise ValueError('unsupported input_type')
        # Validate before running the (expensive) model, not after.
        if output_type not in ('pil', 'np', 'pt', 'none'):
            raise ValueError('unsupported output_type')

        if input_type == 'pil':
            images_gray = self.image_processor.preprocess(
                images_gray, height=self.height, width=self.width
            ).float()
        # Match the VAE's device AND dtype so half-precision models work too.
        images_gray = images_gray.to(device=self.vae.device, dtype=self.vae.dtype)
        batch, _, _, _ = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(batch, 1, 1)

        posterior, encode_residual = self.encode_with_residual(images_gray)
        latents = posterior.mode()
        # Single denoising step at a fixed timestep of 500 for every sample.
        t = torch.full((batch,), 500, dtype=torch.long, device=self.vae.device)
        noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred
        images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)

        if output_type == 'none':
            return images_rgb
        if output_type != 'pt':
            # PIL/numpy conversion goes through Tensor.numpy(), which raises on
            # tensors that require grad. This happens whenever forward() is not
            # under no_grad, since prompt_embeds is a Parameter. Gradients cannot
            # reach these outputs anyway.
            images_rgb = images_rgb.detach()
        return self.image_processor.postprocess(images_rgb, output_type)

    def encode_with_residual(self, sample):
        """Run the VAE encoder and quant_conv step by step, keeping skip activations.

        Args:
            sample: preprocessed image batch on the VAE's device.

        Returns:
            ``(posterior, (re0_out, re1_out, re2_out, rem, re_out))``, where:

            - ``reK_out`` is down block K's activation before its downsamplers.
            - ``rem`` is the encoder mid-block output.
            - ``re_out`` is the tensor that parametrises ``posterior``.
        """
        encoder = self.vae.encoder
        re = encoder.conv_in(sample)
        re0, re0_out = self._DownEncoderBlock2D_res_forward(encoder.down_blocks[0], re)
        re1, re1_out = self._DownEncoderBlock2D_res_forward(encoder.down_blocks[1], re0)
        re2, re2_out = self._DownEncoderBlock2D_res_forward(encoder.down_blocks[2], re1)
        # The last block's pre-downsample activation is not used as a skip.
        re3, _ = self._DownEncoderBlock2D_res_forward(encoder.down_blocks[3], re2)
        rem = encoder.mid_block(re3)
        re_out = encoder.conv_out(encoder.conv_act(encoder.conv_norm_out(rem)))
        # Newer diffusers versions may configure the VAE without quant_conv (None).
        if self.vae.quant_conv is not None:
            re_out = self.vae.quant_conv(re_out)

        posterior = DiagonalGaussianDistribution(re_out)
        return posterior, (re0_out, re1_out, re2_out, rem, re_out)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        """Run post_quant_conv and the VAE decoder step by step with SE-gated skips.

        Each decoder stage input is multiplied by ``self.se_paths[i]``. The gate
        for stage ``i`` is computed from the matching encoder activation
        returned by ``encode_with_residual``.

        Args:
            z: latent to decode.
            re0_out, re1_out, re2_out, rem, re_out: encoder activations, in the
                order returned by ``encode_with_residual``.

        Returns:
            The decoder's ``conv_out`` output tensor.
        """
        decoder = self.vae.decoder
        rd = self.se_paths[0](re_out, z)
        if self.vae.post_quant_conv is not None:
            rd = self.vae.post_quant_conv(rd)
        rd = decoder.conv_in(rd)
        rdm = decoder.mid_block(self.se_paths[1](rem, rd))
        # Cast to the up-blocks' parameter dtype, as diffusers' Decoder.forward does.
        # The previous hard-coded float32 crashed fp16/bf16 models.
        upscale_dtype = next(iter(decoder.up_blocks.parameters())).dtype
        rdm = rdm.to(upscale_dtype)
        rd0 = decoder.up_blocks[0](rdm)
        rd1 = decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
        rd2 = decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
        rd3 = decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
        rd_out = decoder.conv_act(decoder.conv_norm_out(rd3))
        return decoder.conv_out(rd_out)

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        """Run a down-encoder block and also return its pre-downsample activation.

        Args:
            down_encoder_block_2d: block exposing ``resnets`` and ``downsamplers``
                (the latter may be None).
            hidden_states: block input.

        Returns:
            ``(hidden_states, output_states)``: the output after any
            downsamplers, and the activation right after the last resnet.
        """
        for resnet in down_encoder_block_2d.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        output_states = hidden_states
        if down_encoder_block_2d.downsamplers is not None:
            for downsampler in down_encoder_block_2d.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states