import argparse
import logging
import time
import gradio as gr
import torch
from transformers import pipeline

from utils import make_mailto_form, postprocess, clear, make_email_link

# Root logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# True when torch reports a usable CUDA device; used to pick the pipeline device.
use_gpu = torch.cuda.is_available()


def generate_text(
    prompt: str,
    gen_length=64,
    penalty_alpha=0.6,
    top_k=6,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    # perma params (not set by user)
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - continue an email prompt with the module-level text-generation pipeline

    Reads the global ``generator`` pipeline (created at startup or by
    ``load_emailgen_model``) and generates with contrastive-search parameters.

    Args:
        prompt (str): the (partial) email text to continue
        gen_length (int, optional): number of new tokens to generate beyond the prompt. Defaults to 64.
        penalty_alpha (float, optional): contrastive-search penalty alpha. Defaults to 0.6.
        top_k (int, optional): contrastive-search top k. Defaults to 6.
        no_repeat_ngram_size (int, optional): size of n-grams that may not repeat. Defaults to 2.
        length_penalty (float, optional): passed through to the pipeline. Defaults to 1.0.
        abs_max_length (int, optional): prompt token count above which a warning is logged. Defaults to 512.
        verbose (bool, optional): log the prompt, parameters and raw output. Defaults to False.

    Returns:
        tuple: (result of ``make_mailto_form`` for the post-processed email,
        the post-processed email text)
    """
    # Gradio sliders/radios may hand back floats; these generation arguments
    # are token counts / sizes and must be integers.
    gen_length = int(gen_length)
    top_k = int(top_k)
    no_repeat_ngram_size = int(no_repeat_ngram_size)

    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, top_k={top_k}, no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}, abs_max_length={abs_max_length}"
        )
    st = time.perf_counter()

    # `generator` is the module-level pipeline; it is only read here.
    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        logging.warning(
            f"Input too long {input_len} > {abs_max_length}, may cause errors"
        )
    # For decoder-only models max_length/min_length include the prompt, so
    # offset both by the prompt length: up to `gen_length` new tokens, at least 4.
    result = generator(
        prompt,
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
    )  # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
    rt_string = f"Generation time: {rt:.2f}s"
    logging.info(rt_string)

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email


def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - (re)load the module-level text-generation pipeline

    Builds a "text-generation" pipeline for ``model_tag`` on device 0 when
    CUDA is available (``use_gpu``), otherwise on CPU (device -1), and stores
    it in the global ``generator`` that ``generate_text`` reads.

    Args:
        model_tag (str): the huggingface model tag to load

    Returns:
        None: the pipeline is stored in the global ``generator`` rather than returned.
    """
    global generator
    st = time.perf_counter()
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )
    logging.info(f"Loaded model {model_tag} in {time.perf_counter() - st:.2f}s")


def get_parser():
    """
    get_parser - build the command-line parser for the email auto-complete demo

    Options:
        -m/--model          huggingface model tag (default "postbot/distilgpt2-emailgen-V2")
        -v/--verbose        verbose output flag
        -a/--penalty_alpha  contrastive-search penalty alpha (float, default 0.6)
        -k/--top_k          contrastive-search top k (int, default 6)

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different huggingface model tag to use a custom model",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )

    parser.add_argument(
        "-a",
        "--penalty_alpha",
        type=float,
        default=0.6,
        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
    )

    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=6,
        help="The top k for the text generation pipeline (contrastive search) - default 6",
    )
    return parser


default_prompt = """
Hello,

Following up on last week's bubblegum shipment, I"""

available_models = [
    "postbot/distilgpt2-emailgen-V2",
    "postbot/distilgpt2-emailgen",
    "postbot/gpt2-medium-emailgen",
]

if __name__ == "__main__":

    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    top_k = args.top_k
    alpha = args.penalty_alpha

    assert top_k > 0, "top_k must be greater than 0"
    assert alpha >= 0.0 and alpha <= 1.0, "penalty_alpha must be between 0 and 1"

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )

    demo = gr.Blocks()

    logging.info("launching interface...")

    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")

        with gr.Column():

            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=32,
                    maximum=96,
                    minimum=16,
                    step=8,
                )

            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("---")
            gr.Markdown("### Results")
            # put a large HTML placeholder here
            generated_email = gr.Textbox(
                label="Generated Text",
                placeholder="This is where the generated text will appear",
                interactive=False,
            )
            email_mailto_button = gr.HTML(
                "<i>a clickable email button will appear here</i>"
            )

            gr.Markdown("---")
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via beam search. See details about these parameters [here](https://huggingface.co/blog/how-to-generate), otherwise they should be fine as-is."
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=available_models,
                    label="Choose a model",
                    value=model_tag,
                )
                load_model_button = gr.Button(
                    "Load Model",
                    variant="secondary",
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[1, 2, 3, 4],
                    label="no repeat ngram size",
                    value=3,
                )
            with gr.Row():
                contrastive_top_k = gr.Radio(
                    choices=[2, 4, 6, 8],
                    label="Top K",
                    value=top_k,
                )

                penalty_alpha = gr.Slider(
                    label="Penalty Alpha",
                    value=alpha,
                    maximum=1.0,
                    minimum=0.0,
                    step=0.1,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="Length Penalty",
                    value=1.0,
                    step=0.1,
                )
            gr.Markdown("---")

        with gr.Column():

            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            fn=generate_text,
            inputs=[
                prompt_text,
                num_gen_tokens,
                penalty_alpha,
                contrastive_top_k,
                no_repeat_ngram_size,
                length_penalty,
            ],
            outputs=[email_mailto_button, generated_email],
        )

        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )
    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )