import os

import gradio as gr
from huggingface_hub import InferenceClient

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = "https://api-inference.huggingface.co/models/DataAnalyticsLab/PersianGPT-FT-Grover"
BOT_NAME = "PersianGPT-FT"

STOP_SEQUENCES = ["<|endoftext|>"]

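# Each example appears to encode a poem genre tag between "<$" and "$", optionally
# followed by "@" and an opening hemistich; the genres below are ghazal (غزل),
# qasida (قصیده), and masnavi (مثنوی).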
EXAMPLES = [
    ["<$غزل$@بر لبم هر ذره داغی می توان کردن"],
    ["<$غزل$"],
    ["<$قصیده$"],
    ["<$مثنوی$"],
    ["<$غزل$@دراین سرای بی کسی، کسی به در نمی زند"]
    ]

client = InferenceClient(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
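# The client targets the hosted Inference API endpoint above; passing the token in
# an Authorization header is equivalent to InferenceClient(API_URL, token=HF_TOKEN).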

def format_prompt(message, history, system_prompt):
    """Concatenate the system prompt, chat history, and new message into one prompt string."""
    prompt = ""
    if system_prompt:
        prompt += system_prompt
    for user_prompt, bot_response in history:
        prompt += user_prompt
        prompt += bot_response
    prompt += message
    return prompt

def generate(
    prompt,
    history,
    system_prompt="<|endoftext|>",
    temperature=0.9,
    max_new_tokens=100,
    top_p=0.95,
    repetition_penalty=1.0,
    seed=42,
):
    # The Inference API rejects non-positive temperatures, so clamp at a small floor.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        # seed=seed,  # accepted but not forwarded; seed support varies across backends
    )
    formatted_prompt = format_prompt(prompt, history, system_prompt)

    # Stream tokens so partial output reaches the chat UI as it is generated.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""

    for response in stream:
        output += response.token.text

        # Strip any stop sequence that leaked into the generated text.
        for stop_str in STOP_SEQUENCES:
            if output.endswith(stop_str):
                output = output[: -len(stop_str)].rstrip()
        yield output
    return output
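# A minimal smoke test outside the Gradio UI (hypothetical usage; assumes a valid
# HF_TOKEN in the environment and a reachable endpoint):
#
#     for partial in generate("<$غزل$", history=[]):
#         print(partial)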


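# gr.ChatInterface passes these controls to generate() positionally after
# (message, history), so their order must match the keyword arguments above.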
additional_inputs=[
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=100,
        minimum=0,
        maximum=250,
        step=10,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

CSS = """
.gradio-container textarea {direction: rtl; white-space: pre-line;}
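/* NOTE: the #component-* IDs below are auto-generated by Gradio and may change between versions. */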
#component-11 #component-12 {direction: rtl; white-space: pre-line;}
p {direction: rtl; white-space: pre-line;}
"""
with gr.Blocks(css=CSS) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                PersianGPT, trained by Mojtaba Valipour @ Data Analytics Lab
                """
            )

    gr.ChatInterface(
        generate, 
        examples=EXAMPLES,
        additional_inputs=additional_inputs
    ) 

demo.queue(concurrency_count=100, api_open=False).launch(show_api=False)  # pass share=True to launch() for a public link