from huggingface_hub import InferenceClient
import gradio as gr
import random
     
from prompts import GAME_MASTER, COMPRESS_HISTORY, ADJUST_STATS
def format_prompt(message, history):
    """Render a conversation in the ``[INST] ... [/INST]`` instruction format.

    The result starts with a single ``<s>`` marker. Every prior ``(user, bot)``
    pair in *history* becomes ``[INST] user [/INST] bot</s> ``, and *message*
    is appended as a final, still-unanswered ``[INST] message [/INST]`` turn.
    """
    past_turns = [
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"


# Default sampling parameters, used as the keyword defaults of
# compress_history() and generate().
temperature, top_p, repetition_penalty = 0.99, 0.95, 1.0

def compress_history(history,temperature=temperature,top_p=top_p,repetition_penalty=repetition_penalty):
    """Ask the model to condense a long game transcript.

    Args:
        history: The transcript to compress, inserted into the
            ``COMPRESS_HISTORY`` prompt template (``generate`` passes
            ``str(history)``).
        temperature: Sampling temperature; clamped to at least ``1e-2`` as in
            ``generate``.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty forwarded to the model.

    Returns:
        The generated text, up to 1024 new tokens, with special tokens (such as
        the end-of-sequence marker) removed.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    print("COMPRESSING")
    # Sampling requires a strictly positive temperature; match generate()'s clamp.
    temperature = max(float(temperature), 1e-2)
    formatted_prompt = COMPRESS_HISTORY.format(history=history)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=1024,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 99999999999),
    )
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    # With details=True each streamed item carries token metadata; skip special
    # tokens so e.g. the end-of-sequence marker is not baked into the summary
    # that generate() embeds into the next prompt.
    return "".join(
        response.token.text for response in stream if not response.token.special
    )

# MAX_HISTORY: total chat line count above which generate() asks the model to
#   compress the transcript before building the next prompt.
# opts: choices parsed from the latest model reply; module-level, so it is
#   shared by every call of generate().
MAX_HISTORY, opts = 100, []
def _count_history_lines(history):
    """Return the total number of text lines across all string messages in *history*."""
    # Non-string entries (e.g. a None placeholder) are skipped instead of raising.
    return sum(
        len(message.split("\n"))
        for turn in history
        for message in turn
        if isinstance(message, str)
    )


def _parse_choices(lines):
    """Collect the numbered options listed under each ``Choices:`` line.

    Up to four lines after the header are inspected. A line is taken when it
    contains its own number (1-4); the option text is everything after its
    first space.
    """
    found = []
    for i, line in enumerate(lines):
        if "Choices:" not in line:
            continue
        for z in range(1, 5):
            if i + z >= len(lines):
                break
            candidate = lines[i + z]
            if str(z) not in candidate:
                continue
            parts = candidate.split(" ", 1)
            if len(parts) > 1:
                print(parts[1])
                found.append(parts[1])
    return found


def _parse_stats(lines):
    """Extract ``(label, int_value)`` pairs from ``Label: value`` lines, in order.

    The value is the first space-separated token of the segment after the first
    ``": "``, cut at any ``<`` so trailing markup such as ``</s>`` is ignored.
    Lines whose value is not an integer are reported and skipped.
    """
    pairs = []
    for line in lines:
        if ": " not in line:
            continue
        pieces = line.split(": ")
        token = pieces[1].split(" ")[0].split("<")[0]
        try:
            value = int(token)
        except ValueError as e:
            print(f'--Error :: {e}')
            continue
        pairs.append((pieces[0], value))
    return pairs


def generate(prompt, history,max_new_tokens,health,temperature=temperature,top_p=top_p,repetition_penalty=repetition_penalty):
    """Stream one Game-Master turn, then update the character sheet and choices.

    Args:
        prompt: The player's action (the dropdown value).
        history: Prior chat turns as ``(user, bot)`` pairs; ``None`` is treated
            as no history.
        max_new_tokens: Generation budget for this turn.
        health: Current character-sheet text, passed to ``GAME_MASTER`` as
            ``stats``.
        temperature: Sampling temperature, clamped to at least ``1e-2``.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty forwarded to the model.

    Yields:
        ``(chat, stats, skills, choices_dropdown)`` tuples. While streaming,
        ``chat`` is the prior turns plus the partial reply, and the last two
        items are ``None``. The final yield carries ``[dict]`` of parsed stats
        and a ``gr.Dropdown`` filled with the parsed choices. If the reply
        contains no parseable stat lines, ``stats`` keeps the incoming sheet
        instead of being blanked.
    """
    # opts is module-level (shared across sessions); it is kept up to date for
    # compatibility, but the dropdown below is built from a local list.
    opts.clear()
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Sampling requires a strictly positive temperature.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 99999999999),
    )

    history = history or []
    stats = health

    cnt = _count_history_lines(history)
    print(f'cnt:: {cnt}')
    history1 = history
    if cnt > MAX_HISTORY:
        history1 = compress_history(str(history), temperature, top_p, repetition_penalty)
    # NOTE(review): even when compressed, the full `history` is still passed to
    # format_prompt below, so the request does not actually shrink; confirm
    # whether the raw turns should be dropped in that case.
    game_master = GAME_MASTER.format(history=history1, stats=stats, dice=random.randint(1, 100))
    formatted_prompt = format_prompt(f"{game_master}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

    output = ""
    for response in stream:
        # Skip special tokens (e.g. end-of-sequence) so they never reach the chat.
        if response.token.special:
            continue
        output += response.token.text
        # Include earlier turns so the chat does not collapse to the current
        # exchange while the reply is streaming.
        yield history + [(prompt, output)], stats, None, None

    lines = output.strip().split("\n")

    choices = _parse_choices(lines)
    opts.extend(choices)

    pairs = _parse_stats(lines)
    skills = [dict(pairs)]
    if pairs:
        stats = (
            "*******************\n"
            + "".join(f"{label}: {value}\n" for label, value in pairs)
            + "*******************\n"
        )
    # Otherwise keep the previous sheet rather than replacing it with empty borders.

    option_drop = gr.Dropdown(label="Choices", choices=choices)
    yield history + [(prompt, output)], stats, skills, option_drop

def clear_fn():
    """Reset the prompt dropdown and the chat for the ``[opt, chatbot]`` outputs.

    The dropdown is reset to ``"Start a new game"`` rather than blanked. A
    blank (``None``) value would be sent as the next player action.

    NOTE(review): the character sheet (char_stats / json_out) is not reset,
    because it is not wired as an output of the Clear button.
    """
    return "Start a new game", None

# Starting character sheet, shown in the JSON panel.
base_stats=[
    {"Health":100,"Power":20,"Strength":24},
]
# The same sheet as text, in the "Label: value" layout between asterisk borders
# that generate() parses and re-emits. It is derived from base_stats so the two
# cannot drift apart.
text_stats = (
    "*******************\n"
    + "".join(f"{label}: {value}\n" for label, value in base_stats[0].items())
    + "*******************\n"
)
    
# Two-column layout: chat + controls on the left, character sheet on the right.
with gr.Blocks() as app:
    gr.HTML("""<center><h1>Mixtral 8x7B RPG</h1><h3>Role Playing Game Master</h3></center>""")
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Mixtral 8x7B Game Master",height=500, layout='panel', show_copy_button=True)
                with gr.Row():
                    with gr.Column(scale=3):
                        # Player action: one of the parsed choices, or free text.
                        opt=gr.Dropdown(label="Choices",choices=["Start a new game"],allow_custom_value=True, value="Start a new game", interactive=True)
                    with gr.Column(scale=2):
                        button=gr.Button()
                with gr.Row():
                    stop_button=gr.Button("Stop")
                    clear_btn = gr.Button("Clear")
                with gr.Row():
                    # Hidden, non-interactive: fixes the per-turn budget passed to generate().
                    tokens = gr.Slider(label="Max new tokens",value=2096,minimum=0,maximum=1048*10,step=64,interactive=False, visible=False,info="The maximum numbers of new tokens")
            with gr.Column(scale=1):
                json_out=gr.JSON(value=base_stats)
                char_stats=gr.Textbox(value=text_stats)
    clear_btn.click(clear_fn,None,[opt,chatbot])
    go=button.click(generate,[opt,chatbot,tokens,char_stats],[chatbot,char_stats,json_out,opt])
    # Stop cancels the in-flight generate() stream.
    stop_button.click(None,None,None,cancels=[go])

if __name__ == "__main__":
    app.launch(show_api=False)



# NOTE: Inert alternative front end, kept as a bare string literal so its
# contents are never executed. It sketches a gr.ChatInterface around generate().
# It references `additional_inputs`, which is not defined in this file, so it
# would fail if un-quoted as-is. TODO: delete, or port into the Blocks UI.
'''
examples=[["Start the Game", None, None, None, None, None, ],
          ["Start a Game based in the year 1322", None, None, None, None, None,],
         ]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral RPG Game Master",
    examples=examples,
    concurrency_limit=20,
).launch(share=True,show_api=True)
'''