import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForSequenceClassification, MarianTokenizer

from fudge.constants import *
from fudge.util import pad_mask
from fudge.clickbait_classifier import BertClickbaitClassifier, ClickbaitConfig
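
# Task-specific attribute predictors used as FUDGE future discriminators: given a
# token-id prefix, each branch below (formality, iambic, rhyme, newline, clickbait)
# scores how well the eventual completion is expected to satisfy the target attribute.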

class Model(nn.Module):
    def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True):
        super().__init__()

#         self.topic = args.task == 'topic'
        self.formality = args.task == 'formality'
        self.iambic = args.task == 'iambic'
        self.rhyme = args.task == 'rhyme'
        self.newline = args.task == 'newline'
        self.clickbait = args.task == 'clickbait'
#         if self.topic:
#             self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
#             if glove_embeddings is None:
#                 if verbose:
#                     print('initializing word embeddings from scratch')
#                 self.word_embed = nn.Embedding(vocab_size, GLOVE_DIM, padding_idx=0)
#             else:
#                 if verbose:
#                     print('initializing word embeddings from glove')
#                 self.word_embed = nn.Embedding.from_pretrained(glove_embeddings, padding_idx=0)
#             self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
#             self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
#             large_hidden_dim = HIDDEN_DIM
#             self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
#             self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
#             self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
#             self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
#             self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
#             self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
#             self.nonlinear = nn.ReLU()
#         elif self.formality:
        if self.formality:
            self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is ''
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions
            self.out_linear = nn.Linear(HIDDEN_DIM, 1)
        elif self.iambic:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0) # want it to be causal so we can learn all positions
            self.out_linear = nn.Linear(HIDDEN_DIM, 1)
        elif self.rhyme:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
            self.word_embed = nn.Embedding(rhyme_group_size+1, GLOVE_DIM, padding_idx=0) # this embedding for future words will actually embed the rhyme group idx
            self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
            self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            large_hidden_dim = HIDDEN_DIM + COUNT_SYLLABLE_DIM
            self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
            self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
            self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
            self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
            self.nonlinear = nn.ReLU()
        elif self.newline:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False)
            self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
            self.out_linear = nn.Linear(HIDDEN_DIM + COUNT_SYLLABLE_DIM, HIDDEN_DIM)
            self.out_linear2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
            self.nonlinear = nn.ReLU()
        elif self.clickbait:
            # mpnet_config = ClickbaitConfig(
            #     model_type="mpnet",
            #     pretrained_model="sentence-transformers/all-mpnet-base-v2",
            #     num_labels=1,
            #     dropout=0.2,
            #     inner_dim1=256,
            #     inner_dim2=32, 
            #     max_length=25,
            #     load_pretrained=True,
            #     freeze_bert=False,
            # )
            # load a fine-tuned clickbait classifier from the checkpoint supplied via args,
            # e.g. 'ckpt/clickbait_classifier/checkpoint-1464'
            checkpoint = args.checkpoint
            # self.classifier = BertClickbaitClassifier(config=mpnet_config).to(torch.device(args.device))
            self.classifier = BertClickbaitClassifier.from_pretrained(checkpoint).to(torch.device(args.device))
        else:
            raise NotImplementedError # TODO honestly this can/should be refactored into different models


    def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False, attention_mask=None):
        """
        inputs: token ids, batch x seq, right-padded with 0s
        lengths: lengths of inputs; batch
        future_words: batch x N words to check if not predict next token, else batch
        log_probs: N
        syllables_to_go: batch
        """
#         if self.topic:
#             inputs = self.gpt_embed(inputs) # batch x seq x 300
#             inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
#             rnn_output, _ = self.rnn(inputs)
#             rnn_output, _ = pad_packed_sequence(rnn_output)
#             rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
#             hidden = rnn_output
#             attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
#             embed = self.word_embed(future_words) # batch x N x 300
#             embed_query = self.embed_key_linear(embed)
#             attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
#             attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
#             attention_weights = attention_weights * attention_mask.unsqueeze(2)
#             hidden = self.attention_value_linear(hidden)
#             weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
#             unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
#             unnormalized_scores = torch.cat([unnormalized_scores, embed], dim=2)
#             unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
#             unnormalized_scores = self.out_linear3(unnormalized_scores)
#             scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0) 
#             return scores # batch x N of normalized scores or batch x 
#         elif self.formality:
        if self.formality:
            inputs = self.marian_embed(inputs)
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            return self.out_linear(rnn_output).squeeze(2)
        elif self.iambic:
            inputs = self.gpt_embed(inputs)
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            return self.out_linear(rnn_output).squeeze(2)
        elif self.rhyme:
            inputs = self.gpt_embed(inputs) # batch x seq x 300
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            hidden = rnn_output
            attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
            embed = self.word_embed(future_words) # batch x N x 300
            embedded_syllables_to_go = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1) # batch x N x 100
            auxiliary_embed = embedded_syllables_to_go
            embed_query = self.embed_key_linear(torch.cat([embed, auxiliary_embed], dim=2))
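            # dot-product attention: each candidate future word (its embedding concatenated
            # with the syllables-to-go embedding) attends over the prefix hidden states;
            # padded prefix positions are zeroed out via attention_mask below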
            attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
            attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
            attention_weights = attention_weights * attention_mask.unsqueeze(2)
            hidden = self.attention_value_linear(hidden)
            weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x hidden -> batch x N x hidden
            unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
            unnormalized_scores = torch.cat([unnormalized_scores, embed, auxiliary_embed], dim=2)
            unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
            unnormalized_scores = self.out_linear3(unnormalized_scores)
            scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
            return scores # batch x N normalized scores
        elif self.newline:
            inputs = self.gpt_embed(inputs) # batch x seq x 300
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
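            # append the syllables-to-go embedding to every position's hidden state before scoring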
            hidden = torch.cat([rnn_output, self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, rnn_output.shape[1], -1)], dim=2)
            return self.out_linear3(self.nonlinear(self.out_linear2(self.nonlinear(self.out_linear(hidden))))).squeeze(2)
        elif self.clickbait:
            input_ids = torch.as_tensor(inputs) # no-op if inputs is already a tensor
            classifier_output = self.classifier(input_ids=input_ids, attention_mask=attention_mask).logits # batch x num_labels
            classifier_output = classifier_output[None, :, :] # 1 x batch x num_labels
            return classifier_output.squeeze(2) # 1 x batch

        else:
            raise NotImplementedError
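

# Minimal smoke-test sketch (not part of the original module): builds the iambic-task
# variant and runs one forward pass on random token ids. It assumes HIDDEN_DIM is defined
# in fudge.constants (imported above); the Namespace fields and the gpt_pad_id value used
# here are hypothetical stand-ins for how the training script is assumed to pass arguments.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(task='iambic', checkpoint=None, device='cpu')
    gpt_pad_id = 50256  # hypothetical pad id, used only for this sketch
    model = Model(args, gpt_pad_id=gpt_pad_id, vocab_size=gpt_pad_id + 1)
    model.eval()

    batch_size, seq_len = 2, 10
    inputs = torch.randint(0, 1000, (batch_size, seq_len))  # batch x seq token ids
    lengths = torch.tensor([seq_len, seq_len - 3])  # unpadded lengths per example

    with torch.no_grad():
        logits = model(inputs, lengths=lengths)  # batch x seq per-position attribute logits
    print(logits.shape)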