from args import args, config
from items_dataset import items_dataset
from torch.utils.data import DataLoader
from models import Model_Crf, Model_Softmax
from transformers import AutoTokenizer
from tqdm import tqdm
import prediction
import torch
import math

directory = args.SAVE_MODEL_PATH
model_name = "roberta_CRF.pt"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model_crf = Model_Crf(config).to(device)
model_crf.load_state_dict(
    state_dict=torch.load(directory + model_name, map_location=device)
)

model_name = "roberta_softmax.pt"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model_roberta = Model_Softmax(config).to(device)
model_roberta.load_state_dict(
    state_dict=torch.load(directory + model_name, map_location=device)
)


def prepare_span_data(dataset):
    for sample in dataset:
        spans = items_dataset.cal_agreement_span(
            None,
            agreement_table=sample["predict_sentence_table"],
            min_agree=1,
            max_agree=2,
        )
        sample["span_labels"] = spans
        sample["original_text"] = sample["text_a"]
        del sample["text_a"]


def rank_spans(test_loader, device, model, reverse=True):
    """Calculate each span probability by e**(word average log likelihood)"""
    model.eval()
    result = []

    for i, test_batch in enumerate(tqdm(test_loader)):
        batch_text = test_batch["batch_text"]
        input_ids = test_batch["input_ids"].to(device)
        token_type_ids = test_batch["token_type_ids"].to(device)
        attention_mask = test_batch["attention_mask"].to(device)
        labels = test_batch["labels"]
        crf_mask = test_batch["crf_mask"].to(device)
        sample_mapping = test_batch["overflow_to_sample_mapping"]
        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=None,
            crf_mask=crf_mask,
        )
        output = torch.nn.functional.softmax(output[0], dim=-1)

        # make result of every sample
        sample_id = 0
        sample_result = {
            "original_text": test_batch["batch_text"][sample_id],
            "span_ranked": [],
        }
        for batch_id in range(len(sample_mapping)):
            change_sample = False

            # make sure status
            if sample_id != sample_mapping[batch_id]:
                change_sample = True
            if change_sample:
                sample_id = sample_mapping[batch_id]
                result.append(sample_result)
                sample_result = {
                    "original_text": test_batch["batch_text"][sample_id],
                    "span_ranked": [],
                }

            encoded_spans = items_dataset.cal_agreement_span(
                None, agreement_table=labels[batch_id], min_agree=1, max_agree=2
            )
            # print(encoded_spans)
            for encoded_span in encoded_spans:
                # calculate span loss
                span_lenght = encoded_span[1] - encoded_span[0]
                # print(span_lenght)
                span_prob_table = torch.log(
                    output[batch_id][encoded_span[0] : encoded_span[1]]
                )
                if (
                    not change_sample and encoded_span[0] == 0 and batch_id != 0
                ):  # span cross two tensors
                    span_loss += span_prob_table[0][1]  # Begin
                else:
                    span_loss = span_prob_table[0][1]  # Begin
                for token_id in range(1, span_prob_table.shape[0]):
                    span_loss += span_prob_table[token_id][2]  # Inside
                span_loss /= span_lenght

                # span decode
                decode_start = test_batch[batch_id].token_to_chars(encoded_span[0] + 1)[
                    0
                ]
                decode_end = test_batch[batch_id].token_to_chars(encoded_span[1])[0] + 1
                # print((decode_start, decode_end))
                span_text = test_batch["batch_text"][sample_mapping[batch_id]][
                    decode_start:decode_end
                ]
                if (
                    not change_sample and encoded_span[0] == 0 and batch_id != 0
                ):  # span cross two tensors
                    presample = sample_result["span_ranked"].pop(-1)
                    sample_result["span_ranked"].append(
                        [presample[0] + span_text, math.e ** float(span_loss)]
                    )
                else:
                    sample_result["span_ranked"].append(
                        [span_text, math.e ** float(span_loss)]
                    )
        result.append(sample_result)

    # sorted spans by probability
    # for sample in result:
    #     sample["span_ranked"] = sorted(
    #         sample["span_ranked"], key=lambda x: x[1], reverse=reverse
    #     )
    return result


def predict_single(text):
    input_dict = [{"span_labels": []}]
    input_dict[0]["original_text"] = text
    tokenizer = AutoTokenizer.from_pretrained(
        args.pre_model_name, add_prefix_space=True
    )
    prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=prediction_dataset.collate_fn,
    )
    predict_data = prediction.test_predict(prediction_loader, device, model_crf)
    prediction.add_sentence_table(predict_data)

    prepare_span_data(predict_data)
    tokenizer = AutoTokenizer.from_pretrained(
        args.pre_model_name, add_prefix_space=True
    )
    prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=prediction_dataset.collate_fn,
    )
    span_ranked = rank_spans(prediction_loader, device, model_roberta)
    # for sample in span_ranked:
    #     print(sample["original_text"])
    #     print(sample["span_ranked"])

    result = []
    sample = span_ranked[0]
    orig = sample["original_text"]
    cur = 0
    for s, score in sample["span_ranked"]:
        # print()
        # print('ORIG', repr(orig))
        # print('CCUR', repr(orig[cur:]))
        # print('SSSS', repr(s))
        # print()
        end = orig.index(s, cur)
        if cur != end:
            result.append([orig[cur:end], 0])
        result.append([s, score])
        cur = end + len(s)
    if cur < len(orig):
        result.append([orig[cur:], 0])
    return result


if __name__ == "__main__":
    s = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。  2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。  3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。  4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。  5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
    print(predict_single(s))