import datetime
import os
import random
import re
from io import StringIO

import gradio as gr
import pandas as pd
from huggingface_hub import upload_file
from text_generation import Client

from dialogues import DialogueTemplate
from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
                       share_js)

# Credentials come from the environment; either one is None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")  # used when uploading logged dialogues
API_TOKEN = os.getenv("API_TOKEN")  # sent as bearer token to the endpoint
# Hub dataset repository that receives the logged prompts and completions.
DIALOGUES_DATASET = "ehristoforu/dialogues"

# Maps every selectable model name to the URL of its generation endpoint.
model2endpoint = {
    "starchat-beta": "https://api-inference.huggingface.co/models/HuggingFaceH4/starchat-beta",
}
# Model names in declaration order; the first one is the UI default.
model_names = list(model2endpoint)


def randomize_seed_generator():
    """Return a random integer seed between 0 and 1000000 (both inclusive)."""
    return random.randint(0, 1000000)


def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs, model):
    """Upload one prompt/completion record to the dialogues dataset on the Hub.

    The record is serialised as a single JSON line and uploaded to
    ``<date>/<hour>/prompts_<timestamp>.jsonl`` inside ``DIALOGUES_DATASET``.

    Args:
        now: datetime whose date and hour select the destination folder.
        inputs: value stored under the ``"inputs"`` key of the record.
        outputs: value stored under the ``"outputs"`` key of the record.
        generate_kwargs: value stored under the ``"generate_kwargs"`` key.
        model: value stored under the ``"model"`` key.

    Raises:
        Whatever ``upload_file`` or the JSON serialisation raises; nothing is
        caught here.
    """
    # The file name is stamped with the time of this call, not with `now`.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
    file_name = f"prompts_{timestamp}.jsonl"
    data = {"model": model, "inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}

    # The context manager closes the buffer even when serialisation or the
    # upload raises; previously it was only closed on the success path.
    with StringIO() as buffer:
        pd.DataFrame([data]).to_json(buffer, orient="records", lines=True)

        # Push to Hub
        upload_file(
            path_in_repo=f"{now.date()}/{now.hour}/{file_name}",
            path_or_fileobj=buffer.getvalue().encode(),
            repo_id=DIALOGUES_DATASET,
            token=HF_TOKEN,
            repo_type="dataset",
        )


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten earlier turns and the new message into a single prompt string.

    Args:
        inputs: the new user message.
        chatbot: iterable of (user_text, model_text) pairs of earlier turns.
        preprompt: text placed at the very start of the prompt.
        user_name: prefix marking a user message; added when missing.
        assistant_name: prefix marking a model message; added (preceded by
            `sep`) when missing.
        sep: separator placed between messages.

    Returns:
        The preprompt, every earlier turn, the new user message, a separator
        and finally the assistant prefix with trailing whitespace removed.
    """
    assistant_prefix = sep + assistant_name

    def _with_prefix(text, prefix):
        # Leave text that already carries the prefix untouched.
        return text if text.startswith(prefix) else prefix + text

    turns = [
        _with_prefix(user_text, user_name)
        + _with_prefix(model_text, assistant_prefix).rstrip()
        + sep
        for user_text, model_text in chatbot
    ]
    current_message = _with_prefix(inputs, user_name)

    return preprompt + "".join(turns) + current_message + sep + assistant_name.rstrip()


def wrap_html_code(text):
    """Surround `text` with triple backticks when it contains a ``<...>`` tag.

    Text without any ``<...>`` span on a single line is returned unchanged.
    """
    if re.search(r"<.*?>", text) is None:
        return text
    return f"```{text}```"


def has_no_history(chatbot, history):
    """Return True only when both `chatbot` and `history` are empty/falsy."""
    return not (chatbot or history)


def _history_to_chat(history):
    """Pair the flat history into (user, assistant) tuples for display.

    A trailing entry without a partner (a user message that has no reply
    yet) is left out.
    """
    return [
        (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip()))
        for i in range(0, len(history) - 1, 2)
    ]


def generate(
    RETRY_FLAG,
    model_name,
    system_message,
    user_message,
    chatbot,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save=True,
):
    """Stream a model reply to `user_message` and keep the chat state updated.

    This is a generator. After every non-special token it yields
    ``(chat, history, user_message, "")`` where ``chat`` is the list of
    (user, assistant) tuples rebuilt from ``history``.

    Args:
        RETRY_FLAG: truthy when regenerating. The user message is then
            expected to already be the last entry of `history`, and a random
            seed is used instead of the fixed seed 42.
        model_name: key into ``model2endpoint`` selecting the endpoint.
        system_message: passed as ``system`` to ``DialogueTemplate``.
        user_message: the message to answer. When empty, nothing is sent to
            the model and the current state is yielded unchanged.
        chatbot: list of (user, assistant) pairs of earlier turns; used to
            build the prompt.
        history: flat list alternating user and assistant messages; mutated
            in place.
        temperature: sampling temperature, clamped to at least 0.01.
        top_k: accepted for interface compatibility but not forwarded.
        top_p: forwarded to the client as a float.
        max_new_tokens: forwarded to the client.
        repetition_penalty: forwarded to the client.
        do_save: when truthy and ``HF_TOKEN`` is set, the prompt and the
            completion are uploaded through ``save_inputs_and_outputs``.

    Returns:
        The last yielded tuple (as the generator's return value).
    """
    # Don't return meaningless message when the input is empty: leave the
    # conversation as it is instead of querying the model with nothing.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return chatbot, history, user_message, ""

    if RETRY_FLAG:
        seed = randomize_seed_generator()
    else:
        # `history` holds two entries per displayed turn. Drop anything
        # beyond that (e.g. a user message left behind by a call that failed
        # before a reply arrived) so the user/assistant pairing stays aligned.
        del history[2 * len(chatbot):]
        history.append(user_message)
        seed = 42

    messages = []
    for user_data, model_data in chatbot:
        messages.append({"role": "user", "content": user_data})
        messages.append({"role": "assistant", "content": model_data.rstrip()})
    messages.append({"role": "user", "content": user_message})

    dialogue_template = DialogueTemplate(system=system_message, messages=messages)
    prompt = dialogue_template.get_inference_prompt()

    # NOTE(review): `top_k` is not part of the generation parameters, so the
    # corresponding UI control currently has no effect. Confirm whether it
    # should be forwarded before changing the sampling behaviour.
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=4096,
        seed=seed,
        stop_sequences=["<|end|>"],
    )

    client = Client(
        model2endpoint[model_name],
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        timeout=60,
    )
    stream = client.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    # Defined up front so the final return works even if no text token arrives.
    chat = _history_to_chat(history)
    # Track the reply slot explicitly: keying on the stream index would make a
    # leading special token skip the append and let the next token overwrite
    # the user's message in `history`.
    reply_started = False
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        if reply_started:
            history[-1] = output
        else:
            history.append(" " + output)
            reply_started = True

        chat = _history_to_chat(history)

        yield chat, history, user_message, ""

    if HF_TOKEN and do_save:
        # Logging is best effort: a failed upload must not break the chat.
        try:
            now = datetime.datetime.now()
            current_time = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"[{current_time}] Pushing prompt and completion to the Hub")
            save_inputs_and_outputs(now, prompt, output, generate_kwargs, model_name)
        except Exception as e:
            print(e)

    return chat, history, user_message, ""


# Sample prompts passed to the `gr.Examples` component of the UI.
examples = [
    "How can I write a Python function to generate the nth Fibonacci number?",
    "How do I get the current date using shell commands? Explain how it works.",
    "What's the meaning of life?",
    "Write a function in Javascript to reverse words in a given string.",
    "Give the following data {'Name':['Tom', 'Brad', 'Kyle', 'Jerry'], 'Age':[20, 21, 19, 18], 'Height' : [6.1, 5.9, 6.0, 6.1]}. Can you plot one graph with two subplots as columns. The first is a bar graph showing the height of each person. The second is a bargraph showing the age of each person? Draw the graph in seaborn talk mode.",
    "Create a regex to extract dates from logs",
    "How to decode JSON into a typescript object",
    "Write a list into a jsonlines file and save locally",
]


def clear_chat():
    """Return a fresh, empty (chat, history) pair to reset the conversation."""
    empty_chat, empty_history = [], []
    return empty_chat, empty_history


def delete_last_turn(chat, history):
    """Remove the most recent turn from `chat` and `history` in place.

    Args:
        chat: list of displayed (user, assistant) pairs.
        history: flat list holding two entries (user, assistant) per pair.

    Returns:
        The same two (mutated) lists as a ``(chat, history)`` tuple. Nothing
        is changed when either list is empty.
    """
    if chat and history:
        chat.pop()
        # Cut the history back to two entries per remaining pair. Unlike two
        # unconditional pops this cannot raise on a one-entry history and it
        # also discards a dangling entry that has no displayed counterpart.
        del history[2 * len(chat):]
    return chat, history


def process_example(args):
    """Run one example prompt through `generate` and return the reply text.

    The previous body called ``generate(args)`` with a single positional
    argument and unpacked its 4-tuples into two names, so it could never
    complete. `generate` is now called with its full argument list, using
    the same defaults as the parameter controls of the UI, an empty
    conversation, and data saving switched off.

    Args:
        args: the example prompt text.

    Returns:
        The assistant reply of the last streamed update, or an empty string
        when nothing was generated.
    """
    chat = []
    for chat, _history, _last_message, _cleared_input in generate(
        RETRY_FLAG=False,
        model_name=model_names[0],
        system_message="",
        user_message=args,
        chatbot=[],
        history=[],
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        max_new_tokens=512,
        repetition_penalty=1.2,
        do_save=False,
    ):
        pass
    # Each chat entry is a (user, assistant) pair; return the assistant side.
    return chat[-1][1] if chat else ""


# Regenerate response
def retry_last_answer(
    selected_model,
    system_message,
    user_message,
    chat,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save,
):
    """Discard the last assistant reply and stream a new one for the same message.

    This is a generator yielding the same 4-tuples as `generate`.

    Args:
        selected_model: model name forwarded to `generate`.
        system_message: system prompt forwarded to `generate`.
        user_message: current textbox content; replaced by the last user
            message from `history` when there is a turn to regenerate.
        chat: list of displayed (user, assistant) pairs; the last pair is
            removed in place.
        history: flat list alternating user and assistant messages; the last
            (assistant) entry is removed in place.
        temperature, top_k, top_p, max_new_tokens, repetition_penalty,
        do_save: forwarded unchanged to `generate`.
    """
    # Nothing to regenerate: previously this path reached `generate` with the
    # retry flag never assigned (UnboundLocalError), and a one-entry history
    # failed on `history[-1]`. Hand the state back untouched instead.
    if not chat or len(history) < 2:
        yield chat, history, user_message, user_message
        return

    # Removing the previous conversation from chat
    chat.pop(-1)
    # Removing bot response from the history
    history.pop(-1)
    # Getting last message from user
    user_message = history[-1]

    yield from generate(
        True,  # retry flag: the user message is already the last history entry
        selected_model,
        system_message,
        user_message,
        chat,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    )

# Build the playground UI: banner, model selector, chat window, message box
# with action buttons, generation parameters and example prompts.
with gr.Blocks(analytics_enabled=False, css="style.css") as demo:

    with gr.Row():
        with gr.Column():
            gr.Image("StarChat_logo.png", elem_id="banner-image", show_label=False, show_share_button=False, show_download_button=False)
    with gr.Row():
        with gr.Column():
            gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Row():
        selected_model = gr.Radio(choices=model_names, value=model_names[0], label="Current Model", interactive=False)

    with gr.Row():
        with gr.Column():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Playground")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input", lines=2)
            with gr.Row():
                # The "Send" and "Delete last turn" labels previously contained
                # mis-decoded UTF-8 bytes instead of their emoji.
                send_button = gr.Button("▶️ Send", elem_id="send-btn", visible=True)

                regenerate_button = gr.Button("🔄 Regenerate", elem_id="retry-btn", visible=True)

                delete_turn_button = gr.Button("↩️ Delete last turn", elem_id="delete-btn", visible=True)

                clear_chat_button = gr.Button("🗑 Clear chat", elem_id="clear-btn", visible=True)

            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                system_message = gr.Textbox(
                    elem_id="system-message",
                    placeholder="Below is a conversation between a human user and a helpful AI coding assistant.",
                    label="System Prompt",
                    lines=2,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=50,
                    minimum=0.0,
                    maximum=100,
                    step=1,
                    interactive=True,
                    info="Sample from a shortlist of top-k tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=1024,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )
                do_save = gr.Checkbox(
                    value=True,
                    label="Store data",
                    info="You agree to the storage of your prompt and generated text for research and development purposes:",
                )
            # with gr.Group(elem_id="share-btn-container"):
            #     community_icon = gr.HTML(community_icon_html, visible=True)
            #     loading_icon = gr.HTML(loading_icon_html, visible=True)
            # share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[user_message],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

    # Flat list of alternating user / assistant messages.
    history = gr.State([])
    # Hidden, always-False flag passed to `generate` by the send events.
    RETRY_FLAG = gr.Checkbox(value=False, visible=False)

    # To clear out "message" input textbox and use this to regenerate message
    last_user_message = gr.State("")

    # Components shared by every generation event, in the positional order
    # `retry_last_answer` expects; `generate` takes the retry flag first.
    generation_inputs = [
        selected_model,
        system_message,
        user_message,
        chatbot,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    ]
    generation_outputs = [chatbot, history, last_user_message, user_message]

    user_message.submit(
        generate,
        inputs=[RETRY_FLAG, *generation_inputs],
        outputs=generation_outputs,
    )

    send_button.click(
        generate,
        inputs=[RETRY_FLAG, *generation_inputs],
        outputs=generation_outputs,
    )

    regenerate_button.click(
        retry_last_answer,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

    delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])
    # Switching the model starts a new conversation.
    selected_model.change(clear_chat, outputs=[chatbot, history])
    # share_button.click(None, [], [], _js=share_js)

demo.launch(show_api=False)