# Fork of the SantaCoder demo (https://huggingface.co/spaces/bigcode/santacoder-demo)

import html
import os
from typing import List, Tuple, Union

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers import pipeline


description = """# <p style="text-align: center; color: #292b47;"> 🏎️ <span style='color: #3264ff;'>DeciCoder:</span> A Fast Code Generation ModelπŸ’¨ </p>
<span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/deci/decicoder" style="color: #3264ff;">DeciCoder</a>! 
DeciCoder is a 1B parameter code generation model trained on The Stack dataset and released under an Apache 2.0 license. It's capable of writing code in Python, 
JavaScript, and Java. It's a code-completion model, not an instruction-tuned model; you should prompt the model with a function signature and docstring 
and let it complete the rest. The model can also do infilling, specify where you would like the model to complete code with the <span style='color: #3264ff;'>&lt;FILL_HERE&gt;</span>
token.</span>"""

# Hugging Face access token; indexing os.environ raises KeyError if it is unset.
token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
FIM_PAD = "<fim_pad>"
EOD = "<|endoftext|>"

GENERATION_TITLE= "<p style='font-size: 24px; color: #292b47;'>πŸ’» Your generated code:</p>"

# Tokenizer for fill-in-the-middle prompts. Batches are padded on the left so
# that generated tokens follow the prompt directly. The FIM control tokens are
# registered as special tokens, with EOD doubling as the pad token.
tokenizer_fim = AutoTokenizer.from_pretrained(
    "bigcode/starcoder",
    use_auth_token=token,
    padding_side="left",
)
tokenizer_fim.add_special_tokens(
    {
        "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
        "pad_token": EOD,
    }
)

# Plain tokenizer (no extra special tokens) for left-to-right completion.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder", use_auth_token=token)

# trust_remote_code=True lets transformers run modeling code shipped in the
# Hub repository (presumably a custom architecture; confirm against the repo).
model = AutoModelForCausalLM.from_pretrained(
    "Deci/DeciCoder-1b",
    trust_remote_code=True,
    use_auth_token=token,
).to(device)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

def post_processing(prompt: str, completion: str) -> str:
    """
    Render a prompt and its generated completion as a styled HTML code panel.

    Both parts are HTML-escaped before they are embedded. Code containing
    ``<``, ``>`` or ``&`` (generics, comparisons, HTML snippets, entities) is
    therefore shown verbatim instead of being interpreted by the browser.

    Args:
        prompt (str): The input code prompt (styled with colour #7484b7).
        completion (str): The generated code completion (styled with colour #ff5b86).

    Returns:
        str: ``GENERATION_TITLE`` followed by the HTML-styled prompt and completion.
    """
    prompt_html = f"<span style='color: #7484b7;'>{html.escape(prompt)}</span>"
    completion_html = f"<span style='color: #ff5b86;'>{html.escape(completion)}</span>"
    code_html = (
        "<br><hr><br><pre style='font-size: 12px'><code>"
        f"{prompt_html}{completion_html}"
        "</code></pre><br><hr>"
    )
    return GENERATION_TITLE + code_html
    

def post_processing_fim(prefix: str, middle: str, suffix: str) -> str:
    """
    Render a fill-in-the-middle (FIM) result as a styled HTML code panel.

    The user-supplied prefix and suffix share one colour (#7484b7), and the
    generated middle is highlighted (#ff5b86). Every part is HTML-escaped, so
    code containing ``<``, ``>`` or ``&`` is displayed literally.

    Args:
        prefix (str): The code before the filled-in part.
        middle (str): The generated middle part of the code.
        suffix (str): The code after the filled-in part.

    Returns:
        str: ``GENERATION_TITLE`` followed by the HTML-styled prefix, middle and suffix.
    """
    prefix_html = f"<span style='color: #7484b7;'>{html.escape(prefix)}</span>"
    middle_html = f"<span style='color: #ff5b86;'>{html.escape(middle)}</span>"
    suffix_html = f"<span style='color: #7484b7;'>{html.escape(suffix)}</span>"
    code_html = (
        "<br><hr><br><pre style='font-size: 12px'><code>"
        f"{prefix_html}{middle_html}{suffix_html}"
        "</code></pre><br><hr>"
    )
    return GENERATION_TITLE + code_html

def fim_generation(prompt: str, max_new_tokens: int, temperature: float) -> str:
    """
    Fill in the part of ``prompt`` marked with ``<FILL_HERE>`` and render it as HTML.

    The prompt is split at the FIRST ``<FILL_HERE>`` marker into a prefix and a
    suffix, and the model generates the middle. Any further markers stay verbatim
    in the suffix, so code after them is never silently dropped. If no marker is
    present, the whole prompt becomes the prefix and the suffix is empty.

    Args:
        prompt (str): The input code prompt containing a <FILL_HERE> marker.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature for generation.

    Returns:
        str: The HTML-styled code with the missing part filled in.
    """
    # partition splits only once, keeping everything after the first marker.
    prefix, _, suffix = prompt.partition("<FILL_HERE>")
    [middle] = infill((prefix, suffix), max_new_tokens, temperature)
    return post_processing_fim(prefix, middle, suffix)

def extract_fim_part(s: str) -> str:
    """
    Extract the generated middle from a decoded fill-in-the-middle output.

    The middle is the text after the first ``FIM_MIDDLE`` token, up to (but not
    including) the first ``EOD`` token that follows it. If no ``EOD`` follows,
    for example when generation hit the token limit, the middle runs to the end
    of the string.

    Args:
        s (str): Decoded model output that still contains the special tokens.

    Returns:
        str: The extracted middle, or ``""`` if ``s`` contains no ``FIM_MIDDLE`` token.
    """
    marker = s.find(FIM_MIDDLE)
    if marker == -1:
        # Without a FIM_MIDDLE token there is no middle slot to extract.
        return ""
    start = marker + len(FIM_MIDDLE)
    # str.find signals "not found" with -1, which is truthy, so an `or`
    # fallback would never fire; check for -1 explicitly instead.
    stop = s.find(EOD, start)
    if stop == -1:
        stop = len(s)
    return s[start:stop]

def infill(
    prefix_suffix_tuples: Union[Tuple[str, str], List[Tuple[str, str]]],
    max_new_tokens: int,
    temperature: float,
) -> List[str]:
    """
    Generate the missing middle for one or more (prefix, suffix) pairs.

    Each pair is formatted as ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``.
    The prompts are tokenized as one left-padded batch with ``tokenizer_fim`` and
    sampled with the module-level ``model``.

    Args:
        prefix_suffix_tuples (Union[Tuple[str, str], List[Tuple[str, str]]]):
            A single (prefix, suffix) tuple or a list of them.
        max_new_tokens (int): Maximum number of tokens to generate per prompt.
        temperature (float): Sampling temperature for generation.

    Returns:
        List[str]: One generated middle per input pair, in input order.
    """
    if isinstance(prefix_suffix_tuples, tuple):
        prefix_suffix_tuples = [prefix_suffix_tuples]

    prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer_fim(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            # The batch was padded by tokenizer_fim (pad token = EOD), so pass that
            # id. The plain `tokenizer` is never given a pad token in this file.
            pad_token_id=tokenizer_fim.pad_token_id,
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    return [
        extract_fim_part(tokenizer_fim.decode(tensor, skip_special_tokens=False)) for tensor in outputs
    ]

def code_generation(prompt: str, max_new_tokens: int, temperature: float = 0.2, seed: int = 42) -> str:
    """
    Generate code for a prompt and return it as styled HTML.

    Prompts containing ``<FILL_HERE>`` are routed to fill-in-the-middle
    generation. All other prompts are completed left to right with nucleus
    sampling (top_p=0.95).

    Args:
        prompt (str): The input code prompt.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float, optional): Sampling temperature for generation. Defaults to 0.2.
        seed (int, optional): Random seed passed to ``set_seed`` before sampling,
            so a given seed repeats its sampling. Defaults to 42. ``None`` leaves
            the RNG state untouched.

    Returns:
        str: The HTML-styled generated code.
    """
    if seed is not None:
        # The UI's Number component may deliver a float; set_seed expects an int.
        set_seed(int(seed))

    if "<FILL_HERE>" in prompt:
        return fim_generation(prompt, max_new_tokens, temperature=temperature)

    # return_full_text=False lets the pipeline strip the prompt itself. This is
    # more robust than slicing by len(prompt), because the decoded prompt need
    # not match the raw input character for character.
    outputs = pipe(
        prompt,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
    )
    completion = outputs[0]["generated_text"]
    return post_processing(prompt, completion)

demo = gr.Blocks(
    css=".gradio-container {background-color: white; color: #292b47}"
)
with demo:
    with gr.Row():
        # Only the middle column gets content; the two scale-1 columns act as side margins.
        _, colum_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
        with colum_2:
            gr.Markdown(value=description)
            code = gr.Code(lines=5, language="python", label="Input code", value="def nth_element_in_fibonnaci(element):\n    \"\"\"Returns the nth element of the Fibonnaci sequence.\"\"\"")

            with gr.Accordion("Additional settings", open=True):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=2048,
                    step=1,
                    value=75,
                    label="Number of tokens to generate",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.01,
                    value=0.2,
                    label="Temperature",
                )
                # gr.Number replaces the deprecated gr.inputs.Number; precision=0
                # rounds the value so the seed reaches the handler as an integer.
                seed = gr.Number(
                    value=42,
                    precision=0,
                    label="Enter a seed value (integer)",
                )
            run = gr.Button(value="👨🏽‍💻 Generate code", size='lg')
            output = gr.HTML(label="💻 Your generated code")

    # Exposed as the "predict" API endpoint as well as the button action.
    event = run.click(code_generation, [code, max_new_tokens, temperature, seed], output, api_name="predict")
    gr.HTML(label="Keep in touch", value="<img src='https://huggingface.co/spaces/Deci/DeciCoder-Demo/resolve/main/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")

demo.launch()