"""
summarize - a module for summarizing text using a model from the Hugging Face model hub
"""
import logging
import pprint as pp

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import validate_pytorch2


def load_model_and_tokenizer(model_name: str) -> tuple:
    """
    load_model_and_tokenizer - load a model and tokenizer by display name or hub ID

    :param str model_name: one of the display names listed in MODEL_OPTIONS below
        (e.g. "Text Summarizer"), or a model name/ID on the Hugging Face hub, which is
        used as-is when it is not one of those display names
    :raises ValueError: if model_name is empty or None
    :return tuple: a tuple containing the model and tokenizer. The model is moved to
        CUDA when available, put in eval mode, and wrapped with torch.compile when
        validate_pytorch2() is truthy and compilation does not raise.
    """
    MODEL_OPTIONS = {
        "Text Summarizer": "pszemraj/long-t5-tglobal-base-16384-book-summary",
        "News Article Summarizer Alpha": "pszemraj/long-t5-tglobal-base-sci-simplify",
        "News Article Summarizer Beta": "pszemraj/long-t5-tglobal-base-sci-simplify-elife",
        "Scientific Document Summarizer Alpha": "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
        "Scientific Document Summarizer Beta": "pszemraj/pegasus-x-large-book-summary",
    }
    # Map a display name to its hub ID; anything that is not a display name is
    # treated as a hub ID directly, so from_pretrained never receives None.
    selected_model_identifier = MODEL_OPTIONS.get(model_name, model_name)
    if not selected_model_identifier:
        raise ValueError(
            "model_name must be a non-empty hub model ID or one of: "
            f"{list(MODEL_OPTIONS)}"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(
        selected_model_identifier,
    ).to(device)
    # inference only: switch off training-time behavior such as dropout
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(selected_model_identifier)

    logging.info(f"Loaded model {selected_model_identifier} to {device}")

    if validate_pytorch2():
        try:
            logging.info("Compiling model with Torch 2.0")
            model = torch.compile(model)
        except Exception as e:
            # compilation is a best-effort speedup; keep the eager model on failure
            logging.warning(f"Could not compile model with Torch 2.0: {e}")
    else:
        logging.info("Torch 2.0 not detected, skipping compilation")

    return model, tokenizer


def summarize_and_score(
    ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs
) -> tuple:
    """
    summarize_and_score - given one sequence of ids and its mask, return a summary and a score for the summary

    Args:
        ids (): token ids for a single sequence; a leading batch dimension is added here
        mask (): the attention mask matching ``ids``
        model (): the model to use for summarization; must provide ``generate``
        tokenizer (): the tokenizer used to decode the generated ids
        is_general_attention_model (bool, optional): if True, ``generate`` is called
            without a ``global_attention_mask``; if False, a global attention mask with
            only the first token set is passed. Defaults to True.
        **kwargs: any additional arguments to pass to ``model.generate``
    Returns:
        tuple (list[str], float): the decoded summary (the list returned by
        ``tokenizer.batch_decode``) and the sequence score rounded to 4 decimals.
        The score is NaN when the generation output has no ``sequences_scores``.
    """
    # add a leading batch dimension, e.g. (seq_len,) -> (1, seq_len)
    ids = ids[None, :]
    mask = mask[None, :]

    # Send inputs to wherever the model lives instead of assuming CUDA; fall back
    # to "cuda if available" for objects that do not expose a ``device``.
    device = getattr(model, "device", None)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = ids.to(device)
    attention_mask = mask.to(device)

    generate_kwargs = {
        "attention_mask": attention_mask,
        "output_scores": True,
        "return_dict_in_generate": True,
    }
    if not is_general_attention_model:
        global_attention_mask = torch.zeros_like(attention_mask)
        # put global attention on <s> token
        global_attention_mask[:, 0] = 1
        generate_kwargs["global_attention_mask"] = global_attention_mask

    summary_pred_ids = model.generate(input_ids, **generate_kwargs, **kwargs)

    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
        # NOTE(review): remove_invalid_values looks like a generate() option; decode
        # probably ignores it here - confirm whether it was meant for generate().
        remove_invalid_values=True,
    )

    # sequences_scores is not produced by every decoding strategy (it is tied to
    # beam search - confirm for your transformers version); degrade to NaN instead
    # of raising AttributeError.
    sequences_scores = getattr(summary_pred_ids, "sequences_scores", None)
    if sequences_scores is None:
        logging.warning(
            "generate() returned no sequences_scores (try num_beams > 1); "
            "reporting score as NaN"
        )
        score = float("nan")
    else:
        score = round(sequences_scores.cpu().numpy()[0], 4)

    return summary, score


def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    min_batch_length=512,
    **kwargs,
) -> list:
    """
    summarize_via_tokenbatches - summarize a long string via batches of tokens

    The text is tokenized once (without special tokens) into overlapping windows of
    ``batch_length`` tokens, each padded to ``batch_length``; every window is then
    summarized independently with ``summarize_and_score``.

    Args:
        input_text (str): the text to summarize. Empty or whitespace-only text
            returns an empty list without calling the tokenizer or model.
        model (): the model to use for summarization
        tokenizer (): the tokenizer to use for summarization
        batch_length (int, optional): the length of each batch. Defaults to 2048.
            Values below ``min_batch_length`` are raised to ``min_batch_length``.
        batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
        min_batch_length (int, optional): the minimum length of each batch. Defaults to 512.

        **kwargs: any additional arguments forwarded to ``summarize_and_score``
            (and from there to the model for inference)
    Returns:
        list: a list of dictionaries containing the input tokens, the summary, and the summary score
    """

    logger = logging.getLogger(__name__)

    if batch_length < min_batch_length:
        logger.warning(
            f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}"
        )
        batch_length = min_batch_length

    # log all input parameters
    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")

    if not input_text or not input_text.strip():
        # With padding="max_length" an empty text would presumably still become one
        # all-padding window; there is nothing to summarize, so skip generation.
        logger.warning("input_text is empty; returning no summaries")
        return []

    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )

    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []

    # the context manager closes the progress bar even if a batch raises
    with tqdm(total=len(in_id_arr)) as pbar:
        for _id, _mask in zip(in_id_arr, att_arr):
            result, score = summarize_and_score(
                ids=_id,
                mask=_mask,
                model=model,
                tokenizer=tokenizer,
                **kwargs,
            )
            score = round(float(score), 4)
            gen_summaries.append(
                {
                    "input_tokens": _id,
                    "summary": result,
                    "summary_score": score,
                }
            )
            logger.debug(f"Score for batch: {score}. num chars: {len(repr(result))}")
            logger.debug(f"Summary:\n\t{result}")
            pbar.update()

    return gen_summaries