from typing import List, Iterator, cast

import copy
import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F
from transformers import BertConfig, BertModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

class Diacritizer(nn.Module):
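    """BERT-based Arabic diacritization model.

    A pretrained subword encoder (`base_model`) produces contextual token
    embeddings that are mean-pooled into word embeddings, concatenated with
    learned character embeddings, and passed through a second character-level
    BERT encoder (`diac_emb_model`) whose outputs are classified into
    `num_classes` (15) diacritic labels.

    `config` (optionally nested under a 'modeling' key) is expected to provide
    'num-chars', 'char-embed-dim', 'diac_model_config', 'dropout', and
    'sentence_dropout'; optional keys include 'base_model', 'token_hidden_size',
    'deep-down-proj', 'deep-cls', 'pre-init-diac-model',
    'keep-token-model-layers', 'num-finetune-last-layers', and 'full-finetune'.
    """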
    def __init__(
            self,
            config,
            device=None,
            load_pretrained=True
    ) -> None:
        super().__init__()
        self._dummy = nn.Parameter(T.ones(1))
        #^ Only used to track the module's device (see the `device` property).

        if 'modeling' in config:
            config = config['modeling']
        self.config = config

        model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        if load_pretrained:
            self.token_model: BertModel = AutoModel.from_pretrained(model_name)
        else:
            base_config = AutoConfig.from_pretrained(model_name)
            self.token_model = AutoModel.from_config(base_config)

        self.num_classes = 15  # flat diacritic classes, decoded into haraka/tanween/shadda in predict()
        self.diac_model_config = BertConfig(**config['diac_model_config'])
        self.token_model_config: BertConfig = self.token_model.config

        self.char_embs      = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
        self.diac_emb_model = self.build_diac_model(self.token_model)

        self.down_project_token_embeds_deep = None
        self.down_project_token_embeds = None
        if 'token_hidden_size' in config:
            if config['token_hidden_size'] == 'auto':
                down_proj_size = self.diac_emb_model.config.hidden_size
            else:
                down_proj_size = config['token_hidden_size']
            if config.get('deep-down-proj', False):
                # Two-layer projection, applied in forward() as a residual on top
                # of the plain linear projection below.
                self.down_project_token_embeds_deep = nn.Sequential(
                    nn.Linear(
                        self.token_model_config.hidden_size + config["char-embed-dim"],
                        down_proj_size * 4,
                        bias=False,
                    ),
                    nn.Tanh(),
                    nn.Linear(
                        down_proj_size * 4,
                        down_proj_size,
                        bias=False,
                    ),
                )
            # Base (single linear) projection; built whenever 'token_hidden_size' is set.
            self.down_project_token_embeds = nn.Linear(
                self.token_model_config.hidden_size + config["char-embed-dim"],
                down_proj_size,
                bias=False,
            )

        classifier_feature_size = self.diac_model_config.hidden_size
        if config.get('deep-cls', False):
            # classifier_feature_size = 512
            self.final_feature_transform = nn.Linear(
                self.diac_model_config.hidden_size
                    + self.token_model_config.hidden_size,
                #^ diac_features + [residual from token_model]
                out_features=classifier_feature_size,
                bias=False
            )
        else:
            self.final_feature_transform = None

        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
        self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)

        self.trim_model_(config)

        self.dropout = nn.Dropout(config['dropout'])
        self.sent_dropout_p = config['sentence_dropout']
        self.closs = F.cross_entropy

    def build_diac_model(self, token_model=None):
        if self.config.get('pre-init-diac-model', False):
            # Initialize the diacritics encoder from a copy of the pretrained token
            # model, reusing the layers that come *after* those kept by trim_model_().
            num_layers = self.config.get('keep-token-model-layers', None)
            assert num_layers is not None, \
                "'pre-init-diac-model' requires 'keep-token-model-layers' to be set"
            model = copy.deepcopy(token_model if token_model is not None else self.token_model)
            model.pooler = None
            model.embeddings.word_embeddings = None
            model.encoder.layer = nn.ModuleList(
                list(model.encoder.layer[num_layers:num_layers*2])
            )
            model.encoder.config.num_hidden_layers = num_layers
        else:
            model = BertModel(self.diac_model_config)
        return model

    def trim_model_(self, config):
        self.token_model.pooler = None
        self.diac_emb_model.pooler = None
        # self.diac_emb_model.embeddings = None
        self.diac_emb_model.embeddings.word_embeddings = None

        num_token_model_kept_layers = config.get('keep-token-model-layers', None)
        if num_token_model_kept_layers is not None:
            self.token_model.encoder.layer = nn.ModuleList(
                list(self.token_model.encoder.layer[:num_token_model_kept_layers])
            )
            self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers

        if not config.get('full-finetune', False):
            # Freeze the pretrained token model, then unfreeze only its last
            # `num-finetune-last-layers` encoder layers (default 4).
            for param in self.token_model.parameters():
                param.requires_grad = False
            finetune_last_layers = config.get('num-finetune-last-layers', 4)
            if finetune_last_layers > 0:
                unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
                for layer in unfrozen_layers:
                    for param in layer.parameters():
                        param.requires_grad = True

    def get_grouped_params(self):
        downstream_params: Iterator[nn.Parameter] = cast(
            Iterator,
            (param
                for module in (self.diac_emb_model, self.classifier, self.char_embs)
                for param in module.parameters())
        )
        pg = {
            'pretrained': self.token_model.parameters(),
            'downstream': downstream_params,
        }
        return pg
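    # Example usage (an illustrative sketch, not part of this module's training code):
    # the two parameter groups can be given different learning rates when building
    # the optimizer, e.g. with AdamW (the learning-rate values are placeholders):
    #
    #   groups = model.get_grouped_params()
    #   optimizer = T.optim.AdamW([
    #       {'params': groups['pretrained'], 'lr': 2e-5},
    #       {'params': groups['downstream'], 'lr': 1e-3},
    #   ])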

    @property
    def device(self):
        return self._dummy.device

    def step(self, xt, yt, mask=None, subword_lengths: T.Tensor = None):
        # ^ word_x, char_x, diac_x are Indices
        # ^ xt             : self.preprocess((word_x, char_x, diac_x)),
        # ^ yt             : T.tensor(diac_y, dtype=T.long),
        # ^ subword_lengths: T.tensor(subword_lengths, dtype=T.long)
        #< Move char_x, diac_x to device because they're small and trainable
        xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
        xt[0] = xt[0].to(self.device)
        xt[1] = xt[1].to(self.device)
        # xt[2] = xt[2].to(self.device)

        yt = yt.to(self.device)
        #^ yt: [b tw tc]

        Nb, Tword, Tchar = xt[1].shape
        if Tword * Tchar < 500:
            diac = self(*xt, subword_lengths)
            loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
        else:
            # Very long sentences: process the batch in slices along the batch
            # dimension and accumulate the summed loss.
            num_chunks = (Tword * Tchar + 299) // 300  # ceil(Tword * Tchar / 300)
            loss = 0
            for i in range(num_chunks):
                _slice = slice(i*300, (i+1)*300)
                chunk = self._slice_batch(xt, _slice)
                if chunk[0].shape[0] == 0:
                    break
                diac = self(*chunk, subword_lengths[_slice])
                chunk_loss = self.closs(diac.view(-1, self.num_classes), yt[_slice].view(-1), reduction='sum')
                loss = loss + chunk_loss

        return loss

    def _slice_batch(self, xt: List[T.Tensor], _slice):
        return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]

    def _slim_batch_size(
            self,
            tx: T.Tensor,
            cx: T.Tensor,
            yt: T.Tensor,
            subword_lengths: T.Tensor
    ):
        #^ tx : [b tt]
        #^ cx : [b tw tc]
        #^ yt : [b tw tc]
        token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        tx = tx[:, :Ttoken]

        char_nonpad_mask = cx.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        cx = cx[:, :Tword, :Tchar]
        yt = yt[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]
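        # For example (illustration only): if the batch is padded to 64 tokens,
        # 40 words, and 12 chars per word, but the longest sample actually uses
        # 18 tokens / 9 words / 7 chars, the tensors are trimmed to
        # tx: [b 18], cx: [b 9 7], yt: [b 9 7], subword_lengths: [b 9].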

        return tx, cx, yt, subword_lengths

    def token_dropout(self, toke_x):
        #^ toke_x : [b tw]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(toke_x.shape, q, device=toke_x.device))
            toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return toke_x

    def sentence_dropout(self, word_embs: T.Tensor):
        #^ word_embs : [b tw dwe]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
            sdo = sdo.detach().unsqueeze(-1).to(word_embs)
            word_embs = word_embs * sdo
            # toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return word_embs

    def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
        y: BaseModelOutputWithPoolingAndCrossAttentions
        y = self.token_model(input_ids, attention_mask=attention_mask)
        z = y.last_hidden_state
        return z

    def forward(
            self,
            toke_x          : T.Tensor,
            char_x          : T.Tensor,
            diac_x          : T.Tensor,
            subword_lengths : T.Tensor,
    ):
        #^ toke_x : [b tt]
        #^ char_x : [b tw tc]
        #^ diac_x/labels : [b tw tc]
        #^ subword_lengths : [b, tw]
        # !TODO Use `subword_lengths` to aggregate subword embeddings first before ...
        # ... passing concatenated contextual embedding to chars in diac_model

        token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
        char_nonpad_mask = char_x.ne(0)

        Nb, Tw, Tc = char_x.shape

        # toke_x = self.token_dropout(toke_x)
        token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
        # token_embs = self.sentence_dropout(token_embs)
        #? Strip BOS,EOS
        token_embs = token_embs[:, 1:-1, ...]

        sent_word_strides = subword_lengths.cumsum(1)
        sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
        for i_b in range(Nb):
            token_embs_ib = token_embs[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw: break
                word_emb = token_embs_ib[start_iw : end_iw].sum(0) / (end_iw - start_iw)
                sent_enc[i_b, i_word] = word_emb
                start_iw = end_iw
        #^ sent_enc: [b tw dwe]

        char_x_flat = char_x.reshape(Nb*Tw, Tc)
        char_nonpad_mask = char_x_flat.gt(0)
        # ^ char_nonpad_mask [b*tw tc]

        char_x_flat = char_x_flat * char_nonpad_mask

        cembs = self.char_embs(char_x_flat)
        #^ cembs: [b*tw tc dce]
        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
        #^ wembs: [b tw dwe] => [b tw _ dwe] => [b*tw tc dwe]
        cw_embs = T.cat([cembs, wembs], dim=-1)
        #^ cw_embs : [b*tw tc dcw] ; dcw = dce + dwe
        cw_embs = self.dropout(cw_embs)

        cw_embs_ = cw_embs
        if self.down_project_token_embeds is not None:
            cw_embs_ = self.down_project_token_embeds(cw_embs)
        if self.down_project_token_embeds_deep is not None:
            cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
        cw_embs = cw_embs_

        diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
        diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
        diac_emb = diac_enc.last_hidden_state
        diac_emb = self.dropout(diac_emb)
        #^ diac_emb: [b*tw tc dce]
        diac_emb = diac_emb.view(Nb, Tw, Tc, -1)

        if self.final_feature_transform is not None:
            # Residual from the sentence encoder, concatenated with the diacritics
            # features before the final transform.
            sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
            final_feature = T.cat([sent_residual, diac_emb], dim=-1)
            final_feature = self.final_feature_transform(final_feature)
            final_feature = T.tanh(final_feature)
            final_feature = self.dropout(final_feature)
        else:
            final_feature = diac_emb

        # final_feature = self.feature_layer_norm(final_feature)
        diac_out = self.classifier(final_feature)
        return diac_out

    def predict(self, dataloader):
        from tqdm import tqdm
        import diac_utils as du
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs, subword_lengths).detach()

            marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
            #^ marks: [b tw tc]

            haraka, tanween, shadda = du.flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )

if __name__ == "__main__":
    model = Diacritizer({
        "num-chars": 36,
        "hidden_size": 768,
        "char-embed-dim": 32,
        "dropout": 0.25,
        "sentence_dropout": 0.2,
        "diac_model_config": {
            "num_layers": 4,
            "hidden_size": 768 + 32,
            "intermediate_size": (768 + 32) * 4,
        },
    }, load_pretrained=False)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(model)
    print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")