from dataset import ParallelTextReader
from torch.utils.data import DataLoader
from accelerate.memory_utils import find_executable_batch_size
from datasets import load_metric
from tqdm import tqdm
import torch
import json
import argparse
import numpy as np


def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
    """
    Build a DataLoader over a ParallelTextReader for the two given files.

    Each batch is transposed by the collate function: a list of per-example
    tuples becomes one list per field, so iterating the loader yields
    ``[field_0_values, field_1_values, ...]``.
    (presumably field 0 is the prediction and field 1 the reference(s);
    confirm against ParallelTextReader)
    """

    def transpose(samples):
        # zip(*samples) groups the i-th element of every sample together.
        return [list(column) for column in zip(*samples)]

    return DataLoader(
        ParallelTextReader(pred_path=pred_path, gold_path=gold_path),
        batch_size=batch_size,
        collate_fn=transpose,
    )


def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """
    Score the predictions in `pred_path` against the references in `gold_path`.

    Computes sacrebleu, rouge, bleu, meteor, ter and BertScore. The BertScore
    precision / recall / f1 values are averaged into single numbers.

    `starting_batch_size` is used both as the dataloader batch size and as the
    first batch size tried for BertScore (wrapped in
    `find_executable_batch_size`). If `output_path` is not None the results are
    also written there as JSON. The results are always printed and returned as
    a dictionary keyed by metric name.
    """
    has_gpu = torch.cuda.is_available()
    device = "cuda:0" if has_gpu else "cpu"
    if has_gpu:
        print("We will use a GPU to calculate BertScore.")
    else:
        print(
            "We will use the CPU to calculate BertScore, this can be slow for large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)

    # (metric id / result key, message printed on load, message printed on compute)
    # The order here fixes the load, add_batch, compute and result-key order.
    metric_specs = (
        ("sacrebleu", "Loading sacrebleu...", "Computing sacrebleu"),
        ("rouge", "Loading rouge...", "Computing rouge score"),
        ("bleu", "Loading bleu...", "Computing bleu score"),
        ("meteor", "Loading meteor...", "Computing meteor score"),
        ("ter", "Loading ter...", "Computing ter score"),
    )

    text_metrics = {}
    for metric_id, loading_message, _ in metric_specs:
        print(loading_message)
        text_metrics[metric_id] = load_metric(metric_id)
    print("Loading BertScore...")
    bert_score = load_metric("bertscore")

    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as progress:
        for predictions, references in dataloader:
            for metric_id, metric in text_metrics.items():
                if metric_id == "bleu":
                    # bleu is fed whitespace-split tokens, and only the first
                    # reference of each example is used.
                    metric.add_batch(
                        predictions=[p.split() for p in predictions],
                        references=[[r[0].split()] for r in references],
                    )
                else:
                    metric.add_batch(predictions=predictions, references=references)
            bert_score.add_batch(predictions=predictions, references=references)
            progress.update(len(predictions))

    result_dictionary = {}
    for metric_id, _, computing_message in metric_specs:
        print(computing_message)
        result_dictionary[metric_id] = text_metrics[metric_id].compute()

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        # The decorator supplies batch_size; the closure only reads the
        # enclosing variables.
        print(f"Computing bert score with batch size {batch_size} on {device}")
        scores = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )

        # Collapse the score sequences into a single average each.
        for key in ("precision", "recall", "f1"):
            scores[key] = np.average(scores[key])

        return scores

    result_dictionary["bert_score"] = inference()

    if output_path is not None:
        with open(output_path, "w") as output_file:
            json.dump(result_dictionary, output_file, indent=4)

    rendered = json.dumps(result_dictionary, indent=4)
    print(f"Results: {rendered}")

    return result_dictionary


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the evaluation.
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )

    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )

    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore, we will automatically reduce it if we find an OOM error.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. If not given, the results will be printed to the console.",
    )

    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        # Adjacent string literals are concatenated verbatim, so the first
        # piece must end with a space or the URL runs into the word "and".
        help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore "
        "and https://github.com/Tiiiger/bert_score for more details.",
    )

    args = parser.parse_args()

    eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )