import os

import gradio as gr
from text_generation import Client

from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css

# Optional Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Inference API endpoint for the OctoCoder model. The URL must not carry
# leading whitespace, and uses the canonical lower-case repo id
# ("bigcode/octocoder"), matching the model link shown in the page header.
API_URL = "https://api-inference.huggingface.co/models/bigcode/octocoder"

# UI theme: Gradio's Monochrome preset recoloured with indigo / blue / slate
# hues, small corner radii, and an "Open Sans" font stack (loaded via
# gr.themes.GoogleFont) followed by generic sans-serif fallbacks.
_theme_fonts = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    font=_theme_fonts,
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)

# Streaming client for the inference endpoint. The bearer token is attached
# only when HF_TOKEN is set: formatting a missing token would send the literal
# header "Bearer None", an invalid credential the API may reject, instead of
# making an anonymous request (headers=None is the Client default).
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)


def generate(
    query: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream an OctoCoder answer to ``query``.

    The query is wrapped in a "Question: ...\\n\\nAnswer:" prompt (a period is
    appended when the query does not already end with one) and sent to the
    inference endpoint through ``client.generate_stream`` with sampling
    enabled and a fixed seed.

    The sampling parameters come straight from UI sliders, whose ranges
    include values the text-generation client/server reject (temperature 0,
    top_p 0 or 1, max_new_tokens 0). They are therefore coerced to numbers and
    clamped into valid ranges before the request is made.

    Args:
        query: The user's instruction/question.
        temperature: Sampling temperature; floored at 0.01.
        max_new_tokens: Maximum number of tokens to generate; at least 1.
        top_p: Nucleus-sampling probability mass. Values >= 1.0 disable
            nucleus filtering (top_p is not sent); values are floored at 0.01.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated generated text after each new token.

    Returns:
        The final text (as the generator's return value) once the
        ``<|endoftext|>`` token is seen or the stream ends.
    """
    if query.endswith("."):
        prompt = f"Question: {query}\n\nAnswer:"
    else:
        prompt = f"Question: {query}.\n\nAnswer:"

    # The client rejects non-positive temperatures; the slider allows 0.0.
    temperature = max(float(temperature), 1e-2)

    # The client requires 0 < top_p < 1. A value of 1.0 means "keep the whole
    # distribution", which is expressed by not sending top_p at all.
    top_p = float(top_p)
    if top_p >= 1.0:
        top_p = None
    else:
        top_p = max(top_p, 1e-2)

    # The "Max new tokens" slider starts at 0, but at least one new token
    # must be requested; the parameter is an integer count.
    max_new_tokens = max(int(max_new_tokens), 1)
    repetition_penalty = float(repetition_penalty)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: identical inputs reproduce identical outputs
    )

    stream = client.generate_stream(prompt, **generate_kwargs)
    output = ""
    for response in stream:
        token_text = response.token.text
        # Stop at the end-of-text marker so it never appears in the output.
        if token_text == "<|endoftext|>":
            return output
        output += token_text
        yield output
    return output


def process_example(*args, **kwargs):
    """Run ``generate`` to completion and return only its final text.

    Used as the ``fn`` of ``gr.Examples``. Gradio calls example functions with
    the input component values as positional arguments, so positional and
    keyword arguments are both forwarded to ``generate``.

    Returns:
        The last text yielded by ``generate``, or ``""`` if it yielded nothing
        (e.g. the model emitted ``<|endoftext|>`` as its first token).
    """
    # Pre-bind so an empty stream returns "" instead of raising UnboundLocalError.
    result = ""
    for result in generate(*args, **kwargs):
        pass
    return result


# Render the query textbox (elem_id="q-input") in a monospace font so code
# pasted into it keeps its alignment.
monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""

# Page stylesheet: hide the "generating" overlay, then append the share-button
# styles and the monospace override above, in that order.
css = "".join([".generating {visibility: hidden}", share_btn_css, monospace_css])

# Header HTML rendered at the top of the page via gr.Markdown.
description = """
<div style="text-align: center;">
    <center><img src='https://raw.githubusercontent.com/bigcode-project/octopack/31f3320f098703c7910e43492c39366eeea68d83/banner.png' width='70%'/></center>
    <br>
    <h1><u> OctoCoder Demo </u></h1>
</div>
<br>
<div style="text-align: center;">
    <p>This demo showcases the capabilities of the <a href="https://huggingface.co/bigcode/octocoder">OctoCoder</a> model by showing how it can generate code by following the instructions provided in the input.</p>
    <p><strong>OctoCoder</strong> is an instruction-tuned model with 15.5B parameters, created by fine-tuning StarCoder on CommitPackFT & OASST.</p>
</div>
"""
# License / intended-use notice rendered below the output box via gr.Markdown.
# The trailing backslash on the first line continues the string without a
# newline, so the notice is a single line of Markdown/HTML.
disclaimer = """⚠️<b>Any use or sharing of this demo constitutes your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within.</b>\
 <br>**Intended Use**: this app and its [supporting model](https://huggingface.co/bigcode) are provided for demonstration purposes; not to serve as a replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card.](https://huggingface.co/bigcode)"""

# Prompt asking the model to explain a small snippet; kept as its own constant
# because the embedded code spans several lines.
_explain_code_prompt = '''Explain the following piece of code
def count_unique(s):
    s = s.lower()
    s_split = list(s)
    valid_chars = [char for char in s_split if char.isalpha() or char == " "]
    valid_sentence = "".join(valid_chars)
    uniques = set(valid_sentence.split(" "))
    return len(uniques)'''

# Clickable examples: each row is [instruction text, max_new_tokens], matching
# the (instruction, max_new_tokens) inputs wired into gr.Examples.
examples = [
    ['Please write a function in Python that performs bubble sort.', 256],
    [_explain_code_prompt, 512],
    ['Write an efficient Python function that takes a given text and returns its Morse code equivalent without using any third party library', 512],
    ['Write a html and css code to render a clock', 8000],
]

# Build the page: a header, a "Settings" accordion with four sampling sliders,
# the input box, a Generate button, a code output panel, the disclaimer, a
# share button and clickable examples. Layout follows the nesting of the
# `with` blocks below, so statement order matters.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Settings", open=True):
                    # Two side-by-side columns of sliders.
                    with gr.Row():
                        column_1, column_2 = gr.Column(), gr.Column()
                        with column_1:
                            # The slider allows 0.0; generate() floors the
                            # temperature at 1e-2 before sending the request.
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                            max_new_tokens = gr.Slider(
                                label="Max new tokens",
                                value=256,
                                minimum=0,
                                maximum=8192,
                                step=64,
                                interactive=True,
                                info="The maximum numbers of new tokens",
                            )
                        with column_2:
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.90,
                                minimum=0.0,
                                maximum=1,
                                step=0.05,
                                interactive=True,
                                info="Higher values sample more low-probability tokens",
                            )
                            repetition_penalty = gr.Slider(
                                label="Repetition penalty",
                                value=1.2,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.05,
                                interactive=True,
                                info="Penalize repeated tokens",
                            )

        with gr.Row():
            with gr.Column():
                # elem_id "q-input" is targeted by the monospace CSS rule.
                instruction = gr.Textbox(
                    placeholder="Enter your query here",
                    lines=5,
                    label="Input",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30, label="Output")
                gr.Markdown(disclaimer)
                # Share-to-community widget; styled by share_btn_css via the
                # "share-btn-container" / "share-btn" element ids.
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # Example rows fill the (instruction, max_new_tokens) inputs.
                # cache_examples=False: example outputs are not precomputed at
                # startup, so process_example is not run when the app starts.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction, max_new_tokens],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

    # generate() is a generator, so each yielded partial text is streamed into
    # the output panel as tokens arrive.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=[output],
    )
    # No Python callback (fn=None): the button only runs the share_js snippet
    # in the browser.
    share_button.click(None, [], [], _js=share_js)
# Enable the request queue (needed for streaming generator outputs) with up to
# 16 concurrent workers, then start the server with debug output.
demo.queue(concurrency_count=16).launch(debug=True)