import os
from functools import partial
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import AutoTokenizer, GPT2Config
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

from .utils import CABlock, _GPT2LMHeadModel
from .configuration_prot2text import Prot2TextConfig

from .pdb2graph import PDB2Graph, download_alphafold_structure
from .graphs import *
from .utils_dataset import *

try:
    from graphein.protein.config import ProteinGraphConfig, DSSPConfig
    from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor
    from graphein.protein.features.nodes.dssp import  phi, psi, asa, rsa, secondary_structure
    from graphein.protein.edges.distance import (add_peptide_bonds,
                                                add_hydrogen_bond_interactions,
                                                add_distance_threshold,
                                                )
except ImportError:
    raise Exception('You need to install graphein from source, in addition to DSSP, to use this model. Please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')

try:
    from torch_geometric.nn import RGCNConv, global_mean_pool
except ImportError:
    raise Exception('You need to install torch geometric and its dependencies to use this model. Please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')



class EncoderRGCN(PreTrainedModel):
    '''
    This class implements the RGCN encoder used to encode the protein structure.
    '''
    def __init__(self, input_dim, hidden_dim=512, n_layers=6, emb_dim=512, dropout=0.2, num_relation=7, prot2text_version='1.0'):
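        # Project node features to hidden_dim, apply n_layers relational graph convolutions,
        # then map the mean-pooled graph representation to emb_dim.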
        super(EncoderRGCN, self).__init__(PretrainedConfig(name='RGCN'))
        self.n_layers = n_layers
        self.output_dim = emb_dim
        self.prot2text_version = prot2text_version

        self.fc0 = nn.Linear(input_dim, hidden_dim)
        self.batchnorm_final = nn.BatchNorm1d(hidden_dim)
        
        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        self.conv = nn.ModuleList(
            [RGCNConv(hidden_dim, hidden_dim, num_relations=num_relation) for _ in range(n_layers)]
        )
      
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, self.output_dim)
      
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.LeakyReLU()
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.main_input_name = 'nothing'  # placeholder: this encoder does not consume `input_ids`

    def forward(self, x: Optional[torch.FloatTensor] = None,
                edge_index: Optional[torch.LongTensor] = None,
                edge_type: Optional[torch.LongTensor] = None,
                batch: Optional[torch.LongTensor] = None,
                **kwargs):
        # project node features, apply the stacked RGCN layers, then mean-pool over each graph
        x = self.relu(self.fc0(x))
        
        for i in range(self.n_layers):
            x = self.conv[i](x, edge_index, edge_type)

        out = global_mean_pool(x, batch)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        
        return out.unsqueeze(1)

class Prot2TextModel(PreTrainedModel):
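    """Prot2Text model: encodes a protein with an RGCN (structure) and/or ESM (sequence)
    encoder and generates its textual description with a GPT-2 decoder via cross-attention."""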
    config_class = Prot2TextConfig
    _keys_to_ignore_on_load_missing = [r"transformer"]
    base_model_prefix = "decoder"
    def __init__(self, config):
        super().__init__(config)

        self.gpt_config = GPT2Config.from_dict(config.gpt_config)
        
        # if we are using RGCN to encode the protein's structure, define the RGCN encoder
        if config.rgcn:
            self.encoder = EncoderRGCN(
                input_dim=config.rgcn_input_dim,
                hidden_dim=self.gpt_config.n_embd,
                n_layers=config.rgcn_n_layers,
                emb_dim=self.gpt_config.n_embd,
                prot2text_version=self.config.prot2text_version,
            )

        # define the GPT2 decoder
        self.decoder = _GPT2LMHeadModel(self.gpt_config)

        # if using ESM to encode protein's sequence, define the ESM layer, the Projection layer and the fusion layer
        if config.esm:
            self.esm_config = PretrainedConfig.from_dict(config.esm_config)
            self.esm = transformers.EsmModel(self.esm_config)
            self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
            if config.cross_esm_graph and config.rgcn:
                self.h = nn.ModuleList([CABlock(self.gpt_config,  layer_idx=i) for i in range(4)])
                self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)
            
        self.config = config
        
        
    def get_encoder(self):
        return self.encoder
        
    def get_decoder(self):
        return self.decoder

    def get_input_embeddings(self):
        if hasattr(self, "transformer"):
            return self.transformer.wte
        return self.decoder.transformer.wte
    
    def warm_up(self, gpt_model=None, esm_model=None):
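        # Initialize the encoder/decoder from pretrained ESM and GPT-2 checkpoints before training.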
        if esm_model is not None:
            self.esm = transformers.EsmModel.from_pretrained(esm_model)
        if gpt_model is not None:    
            self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
            self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
            self.decoder.config = self.gpt_config
                
        
    def forward(self,
                encoder_input_ids: Optional[torch.LongTensor] = None,
                edge_index: Optional[torch.LongTensor] = None,
                batch: Optional[torch.LongTensor] = None,
                x: Optional[torch.FloatTensor] = None,
                edge_type: Optional[torch.LongTensor] = None,
                decoder_input_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                decoder_attention_mask: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                token_type_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                get_graph_emb: Optional[bool] = False,
                **delete_args,
            ):
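        # Encode the structure with the RGCN and/or the sequence with ESM, optionally fuse the two
        # with the cross-attention blocks, then decode with GPT-2 conditioned on that embedding.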
        use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
        return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict
        
        
        if decoder_input_ids is not None and len(decoder_input_ids.size()) == 3:
            decoder_input_ids = decoder_input_ids.squeeze(0) 

        if x is not None and self.config.rgcn:
            graph_emb = self.encoder(x, edge_index, edge_type, batch)
            graph_mask = None
            
        if self.config.esm:
            if self.config.prot2text_version=='1.0':
                if encoder_input_ids.size()[1] != 1021:
                    raise ValueError("For this version of the model you need to PAD/Truncate the amino acid sequence for the ESM model to 1021")
            
            esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=return_dict).last_hidden_state
            esm_emb = self.to_embedding(esm_emb)
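            # Without the fusion module, prepend the graph embedding to the projected ESM embeddings
            # and extend the attention mask with one extra position.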
            if not self.config.cross_esm_graph and self.config.rgcn:
                graph_emb = torch.cat((graph_emb, esm_emb), dim=1)
                t_add = torch.ones((attention_mask.size(0), 1)).to(attention_mask.device)
                attention_mask = torch.cat((t_add, attention_mask), dim=1) 
            elif self.config.cross_esm_graph and self.config.rgcn:
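                # Fuse the ESM embeddings with the graph embedding through the cross-attention
                # (CABlock) layers; the fused output replaces the graph embedding.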
                if past_key_values_graph_esm is None:
                    past_length = 0
                    past_key_values_graph_esm = tuple([None] * len(self.h))
                else:
                    past_length = past_key_values_graph_esm[0][0].size(-2) 
                output_shape = esm_emb.size()
                
                all_self_attentions = () if output_attentions else None
                all_cross_attentions = () if output_attentions and self.gpt_config.add_cross_attention else None
                all_hidden_states = () if output_hidden_states else None
                for i, (block, layer_past) in enumerate(zip(self.h, past_key_values_graph_esm)):
                    outputs = block(
                        esm_emb,
                        layer_past=layer_past,
                        attention_mask=attention_mask,
                        encoder_hidden_states=graph_emb,
                        encoder_attention_mask=graph_mask,
                        use_cache=use_cache,
                        output_attentions=False,
                    )
                    esm_emb = outputs[0]

                esm_emb = self.ln_f(esm_emb)
                esm_emb = esm_emb.view(output_shape)  
                graph_emb = esm_emb
            else:
                graph_emb = esm_emb
        else:
            attention_mask = None
        if self.config.prot2text_version=='1.0':
            attention_mask = None
        if get_graph_emb:
            return graph_emb
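        # Decode with GPT-2, cross-attending to the encoder hidden states (graph and/or ESM).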
    
        transformer_outputs = self.decoder(input_ids=decoder_input_ids,
                                            past_key_values=past_key_values,
                                            attention_mask=decoder_attention_mask,
                                            token_type_ids=token_type_ids,
                                            position_ids=position_ids,
                                            head_mask=head_mask,
                                            inputs_embeds=inputs_embeds,
                                            encoder_hidden_states=graph_emb,
                                            encoder_attention_mask=attention_mask,
                                            labels=labels,
                                            use_cache=use_cache,
                                            output_attentions=output_attentions,
                                            output_hidden_states=output_hidden_states,
                                            return_dict=return_dict,
                                            )
        
        return transformer_outputs
    
    @torch.no_grad()    
    def generate_protein_description(self,
                                    protein_pdbID=None, 
                                    protein_sequence=None,
                                    edge_index: Optional[torch.LongTensor] = None,
                                    x: Optional[torch.FloatTensor] = None,
                                    edge_type: Optional[torch.LongTensor] = None,
                                    tokenizer=None,
                                    device='cpu'
                                     ):
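        """Generate a textual protein description either from an AlphaFold/UniProt ID (the structure
        is downloaded and converted to a graph) or from a raw amino-acid sequence."""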
        
        if self.config.esm and not self.config.rgcn and protein_sequence is None:
            raise ValueError(
                "The model you are trying to use is based only on the protein sequence, please provide an amino-acid protein_sequence"
            )
        if self.config.rgcn and protein_pdbID is None and (x is None or edge_index is None or edge_type is None):
            raise ValueError(
                "The model you are trying to use is based on the protein structure, please provide an AlphaFold ID (an internet connection is required when using protein_pdbID) or provide the triplet of inputs: x (node features), edge_index and edge_type"
            )
        if self.config.esm:
            esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
        
        if protein_pdbID is None and protein_sequence is None:
            raise ValueError(
                "You need to provide either a protein AlphaFold ID or an amino-acid sequence"
            )
            
        if protein_pdbID is not None:
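            # Graphein configuration for the residue graph: node features, edge construction
            # functions and DSSP-based metadata (ASA, phi/psi, RSA, secondary structure).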
            config = {"node_metadata_functions": [amino_acid_one_hot, 
                                                expasy_protein_scale,
                                                meiler_embedding,
                                                hydrogen_bond_acceptor, hydrogen_bond_donor
                                                ],
                    "edge_construction_functions": [add_peptide_bonds,
                                                    add_hydrogen_bond_interactions,
                                                    partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.),],
                    "graph_metadata_functions":[asa,phi, psi, secondary_structure, rsa],
                    "dssp_config": DSSPConfig()}
            config = ProteinGraphConfig(**config)

            PATH_TO_DATA = f"~/.tmp/pdb/pdb"
            OUTPUT_FOLDER = f"~/.tmp/pdb/raw"
            save_dir = f"~/.tmp/pdb/"
            isExist = os.path.exists(PATH_TO_DATA)
            if not isExist:
                os.makedirs(PATH_TO_DATA)
            isExist = os.path.exists(OUTPUT_FOLDER)
            if not isExist:
                os.makedirs(OUTPUT_FOLDER)
            isExist = os.path.exists(save_dir+'processed')
            if not isExist:
                os.makedirs(save_dir+'processed')
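            # Download the AlphaFold structure for this ID and derive the paths of the graph files.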
            
            structure_filename = download_alphafold_structure(uniprot_id=protein_pdbID, out_dir=PATH_TO_DATA)
            if structure_filename is None:
                raise ValueError("Error! the ID does not exist in AlphaFoldDB or you do not have internet connection")
            graph_filename = structure_filename.split('/')
            graph_filename[-2] = 'raw'
            graph_filename[-1] = graph_filename[-1].replace('.pdb', '.pt')
            graph_filename = '/'.join(graph_filename)
            process_filename = structure_filename.split('/')
            process_filename[-2] = 'processed'
            process_filename[-1] = process_filename[-1].replace('.pdb', '.pt')
            process_filename = '/'.join(process_filename)    
            try:
                gpdb = PDB2Graph(root=PATH_TO_DATA, output_folder=OUTPUT_FOLDER, config=config, n_processors=1).create_pyg_graph(structure_filename)
                seq = esmtokenizer(gpdb.sequence, add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
                torch.save(gpdb, graph_filename)
                gpdb.edge_type = [np.array(gpdb.edge_type.transpose(0, 1))]
                gpdb.encoder_input_ids = seq['input_ids']
                gpdb.attention_mask = seq['attention_mask']
                torch.save(gpdb, process_filename)
            except Exception:
                os.remove(structure_filename)
                raise ValueError('Creating the graph failed, the AlphaFold PDB file is probably damaged')
            
            self.eval()
            inputs = gpdb
            inputs = inputs.to_dict()
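            # Rebuild edge_type as one tensor and convert the one-hot edge types to class indices.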
            
            inputs['edge_type'] =  torch.cat([torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0)
            inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
            for key in ['num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates']:
                inputs.pop(key)
            inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
            inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
            inputs["decoder_attention_mask"] = torch.ones(inputs['decoder_input_ids'].shape[0], 1)
            self.to(device)
            inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
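            # Run the encoder once (get_graph_emb=True) and hand the result to the GPT-2 decoder
            # as precomputed encoder_outputs, together with the ESM attention mask.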
            encoder_state = dict()
            encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
            encoder_state['attentions'] = inputs['attention_mask']
            for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
                inputs.pop(key)
            tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'], 
                                            encoder_outputs=encoder_state, 
                                            use_cache=True, 
                                            output_attentions=False, 
                                            output_scores=False, 
                                            return_dict_in_generate=True, 
                                            encoder_attention_mask=inputs['attention_mask'], 
                                            length_penalty=1.0,
                                            no_repeat_ngram_size=None,
                                            early_stopping=False,
                                            num_beams=1)

            generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)

            os.remove(structure_filename)
            os.remove(graph_filename)
            os.remove(process_filename)        
                
            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
            
        else:
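            # Sequence-only path: tokenize the amino-acid sequence with the ESM tokenizer,
            # use a BOS-only decoder prompt and generate directly from the ESM embedding.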
            seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
            inputs={}
            inputs['encoder_input_ids'] = seq['input_ids']
            inputs['attention_mask'] = seq['attention_mask']
            inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
            inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
            
            self.to(device)
            inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
            encoder_state = dict()
            encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
            generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
            
            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
    
    @torch.no_grad()
    def generate(self,
                inputs: Optional[torch.Tensor] = None,
                generation_config: Optional[GenerationConfig] = None,
                logits_processor: Optional[LogitsProcessorList] = None,
                stopping_criteria: Optional[StoppingCriteriaList] = None,
                prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                synced_gpus: Optional[bool] = None,
                assistant_model: Optional["PreTrainedModel"] = None,
                streamer: Optional["BaseStreamer"] = None,
                **kwargs,
            ):
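        # Compute the encoder embedding once, strip the encoder/graph-specific kwargs that the
        # decoder's generate() does not accept, and delegate decoding to the GPT-2 decoder.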
        encoder_state = self(**kwargs, get_graph_emb=True)
        input_ids = kwargs['decoder_input_ids']
        attention_mask = kwargs['decoder_attention_mask']
        kwargs['encoder_attention_mask'] = kwargs['attention_mask']
        if not self.config.cross_esm_graph and self.config.rgcn and self.config.esm:
            t_add = torch.ones((kwargs['encoder_attention_mask'].size(0), 1)).to(kwargs['encoder_attention_mask'].device)
            kwargs['encoder_attention_mask'] = torch.cat((t_add, kwargs['encoder_attention_mask']), dim=1) 
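        # Remove graph-, ESM- and batching-specific keys so only decoder-compatible kwargs remain.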
        for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids', 'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
                    '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates', 'ptr', 'num_nodes',]:
            if key in kwargs.keys():
                kwargs.pop(key)
        return self.decoder.generate(input_ids=input_ids,
                                     generation_config=generation_config,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                     synced_gpus=synced_gpus,
                                     assistant_model=assistant_model,
                                     streamer=streamer,
                                     encoder_outputs={'hidden_states': encoder_state, 'attentions':0},
                                     **kwargs
                                     )