from huggingface_hub import InferenceClient, HfApi, upload_file
import datetime
import gradio as gr
import random
import prompts
import json
import uuid
import os



# Hugging Face credentials and the dataset repo that transcripts are uploaded to.
token = os.environ.get("HF_TOKEN")
username = "omnibus"
dataset_name = "tmp"
# Use the same token as the uploads rather than an empty string, which ignored
# the HF_TOKEN read above. NOTE(review): `api` is not referenced elsewhere in
# this file.
api = HfApi(token=token)
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Module-level state. `summary[0]` and `main_point[0]` are single-slot mutable
# cells that the generator functions read and overwrite. `history` and
# `hist_out` appear unused here: the functions use parameters/locals with the
# same names instead.
history = []
hist_out = []
summary = [""]
main_point = [""]

def format_prompt(message, history):
    """Build an ``[INST]``-style instruction prompt from past turns plus a new message.

    Each ``(user_prompt, bot_response)`` pair in ``history`` becomes
    ``[INST] user [/INST] response</s> ``. The new ``message`` is appended as a
    final, unanswered instruction. The whole prompt starts with ``<s>``.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"

# Persona names; each matches a prompt-template attribute used on `prompts`.
agents = [
    "COMMENTER",
    "BLOG_POSTER",
    "REPLY_TO_COMMENTER",
    "COMPRESS_HISTORY_PROMPT",
]

# Default sampling settings. NOTE(review): the functions in this file take
# their own defaults. Only `generate` reads one of these globals
# (`max_new_tokens2`, as its token budget).
temperature = 0.9
max_new_tokens = 256
max_new_tokens2 = 10480
top_p = 0.95
# A plain float: a trailing comma here would make it the tuple (1.0,).
repetition_penalty = 1.0

def compress_history(formatted_prompt, temperature=0.9, max_new_tokens=10480, top_p=0.95, repetition_penalty=1.0):
    """Stream a model completion for an already-built compression prompt.

    The prompt is sent unchanged. The caller (``generate``) embeds the
    ``prompts.COMPRESS_HISTORY_PROMPT`` instructions itself, so nothing is
    re-formatted here.

    Args:
        formatted_prompt: Complete prompt text passed to ``client.text_generation``.
        temperature: Sampling temperature; values below 1e-2 are raised to 1e-2.
        max_new_tokens: Token budget for the summary.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Passed through to the generation call.

    Returns:
        The concatenated generated text.
    """
    seed = random.randint(1, 1111111111111111)
    temperature = float(temperature)
    # Keep the sampling temperature strictly positive, like the other generators here.
    if temperature < 1e-2:
        temperature = 1e-2

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = "".join(response.token.text for response in stream)
    print(output)
    print(main_point[0])
    return output


def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,):
    """Ask the COMMENTER persona to respond to ``prompt`` and return its text.

    The instructions are ``prompts.COMMENTER`` formatted with the current focus
    ``main_point[0]``. They are prefixed to ``prompt`` and combined with
    ``history`` via ``format_prompt``. ``agent_name`` and ``sys_prompt`` are
    accepted but not used.
    """
    rng_seed = random.randint(1, 1111111111111111)
    instructions = prompts.COMMENTER.format(focus=main_point[0])
    # Floor the temperature at 1e-2 so sampling never gets a zero/negative value.
    temp = max(float(temperature), 1e-2)
    nucleus = float(top_p)

    sampling = {
        "temperature": temp,
        "max_new_tokens": max_new_tokens,
        "top_p": nucleus,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": rng_seed,
    }
    full_prompt = format_prompt(f"{instructions}, {prompt}", history)
    chunks = client.text_generation(full_prompt, **sampling, stream=True, details=True, return_full_text=False)
    return "".join(chunk.token.text for chunk in chunks)

def blog_poster_reply(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
    """Generate a reply using the REPLY_TO_COMMENTER instructions and return it.

    The instructions are ``prompts.REPLY_TO_COMMENTER`` formatted with the
    current focus ``main_point[0]``. ``agent_name`` and ``sys_prompt`` are
    accepted but ignored. NOTE(review): not called by the code in this file.
    """
    seed_value = random.randint(1, 1111111111111111)
    reply_instructions = prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])

    temp = float(temperature)
    if temp < 1e-2:
        temp = 1e-2  # floor so sampling always gets a positive temperature

    sampling_options = dict(
        seed=seed_value,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        top_p=float(top_p),
        max_new_tokens=max_new_tokens,
        temperature=temp,
    )
    request = format_prompt(f"{reply_instructions}, {prompt}", history)

    pieces = []
    for event in client.text_generation(request, **sampling_options, stream=True, details=True, return_full_text=False):
        pieces.append(event.token.text)
    return "".join(pieces)


    
# Characters kept by create_valid_filename: ASCII letters, digits, '_' and '-'.
_VALID_FILENAME_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789_-"
)


def create_valid_filename(invalid_filename: str) -> str:
    """Convert an arbitrary string into a conservative filename.

    Each run of whitespace becomes a single '-'; leading and trailing
    whitespace is dropped. Then every character that is not an ASCII letter,
    digit, '_' or '-' is removed.

    >>> create_valid_filename("My Title: v2.0!")
    'My-Title-v20'
    """
    # str.split() with no argument splits on any whitespace run and ignores
    # leading/trailing whitespace.
    dashed = "-".join(invalid_filename.split())
    return "".join(ch for ch in dashed if ch in _VALID_FILENAME_CHARS)





    
def load_html(inp=None, template_path="index.html"):
    """Render conversation turns into the HTML page template.

    Args:
        inp: Iterable of ``(outp, prom)`` pairs (``generate`` passes its
            ``full_conv`` list), or ``None``/empty for a page with no turns.
            It defaults to ``None`` so the function also works as a zero-input
            load callback.
        template_path: Template file whose ``$body`` placeholder receives the
            rendered turns.

    Returns:
        The template text with every ``$body`` replaced by the rendered turns.
    """
    # Local import: the module-level name `html` is rebound to a gr.HTML component.
    from html import escape

    parts = []
    for outp, prom in inp or ():
        print(f'outp:: {outp}')
        print(f'prom:: {prom}')
        # Model output is escaped so text produced from user prompts can't inject markup.
        parts.append(f"""<div class="div_box">
            <div class="resp">{escape(str(outp))}</div>
            <div class="resp">{escape(str(prom))}</div>
            </div>""")
    with open(template_path, "r", encoding="utf-8") as h:
        template = h.read()
    return template.replace("$body", "".join(parts))



def _extract_title(text):
    """Return the text after the first ':' on the first line mentioning "title".

    "title" is matched case-insensitively and the result is whitespace-stripped.
    Returns "" when no line qualifies.
    """
    for line in text.split("\n"):
        if "title" in line.lower() and ":" in line:
            # maxsplit=1 keeps titles that themselves contain ':' intact.
            return line.split(":", 1)[1].strip()
    return ""


def _upload_json(local_path, path_in_repo):
    """Upload ``local_path`` to ``path_in_repo`` in the ``username/dataset_name`` dataset repo."""
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=f"{username}/{dataset_name}",
        repo_type="dataset",
        token=token,
    )


def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
    """Run the endless blog-post / commenter loop, streaming UI updates.

    The first turn uses ``prompts.BLOG_POSTER``. Later turns use
    ``prompts.REPLY_TO_COMMENTER`` with the current focus ``main_point[0]``.
    After each post, ``question_generate`` writes the next prompt.

    When the formatted prompt reaches 40000 characters, the conversation is
    condensed into ``summary[0]`` by ``compress_history`` and the history is
    cleared. The transcript (``<uid>.json``) or the summaries
    (``<uid>-sum.json``) are written and uploaded after every turn.

    The loop has no exit condition; it runs until the Gradio event is
    cancelled.

    NOTE(review): ``agent_name``, ``sys_prompt`` and ``max_new_tokens`` are
    accepted but unused. Every call uses the module-level ``max_new_tokens2``.

    Yields:
        ``(msg, chatbot, summary, summaries_json, history_json, html)`` tuples.
    """
    html_out = ""
    uid = uuid.uuid4()
    # Sanitize the timestamp *before* building any filename, so the untitled
    # and titled filenames use the same timestamp format.
    current_time = str(datetime.datetime.now()).replace(":", "-").replace(".", "-")
    print(current_time)
    title = ""
    filename = create_valid_filename(f"{current_time}---{title}")

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    hist_out = []
    sum_out = []
    json_hist = {}
    json_obj = {}
    full_conv = []
    post_cnt = 1
    while True:
        seed = random.randint(1, 1111111111111111)
        if post_cnt == 1:
            system_prompt = prompts.BLOG_POSTER
            post_cnt += 1
        else:
            system_prompt = prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens2,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=seed,
        )
        # A prompt starting with ' "' gets surrounding spaces/quotes stripped.
        if prompt.startswith(' "'):
            prompt = prompt.strip(' "')
        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

        if len(formatted_prompt) < 40000:
            print(len(formatted_prompt))
            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
            output = ""
            for response in stream:
                output += response.token.text
                yield "", [(prompt, output)], summary[0], json_obj, json_hist, html_out

            # Name the uploaded files after the first title the model produces.
            if not title:
                title = _extract_title(output)
                if title:
                    print(f"title:: {title}")
                    filename = create_valid_filename(f"{current_time}---{title}")

            out_json = {"prompt": prompt, "output": output}
            prompt = question_generate(output, history)
            # NOTE(review): this pairs the *new* commenter prompt with the post
            # that preceded it. format_prompt treats pairs as (user, bot), so the
            # model sees the comment before the post it answers. Confirm intended.
            history.append((prompt, output))
            print(f"Prompt:: {len(prompt)}")
            print(f"history:: {len(formatted_prompt)}")
            hist_out.append(out_json)
            json_hist = json.dumps(hist_out, indent=4)
            with open(f"{uid}.json", "w", encoding="utf-8") as f:
                f.write(json_hist)
            _upload_json(f"{uid}.json", f"book1/{filename}.json")
        else:
            # Prompt too long: condense the conversation into summary[0] and
            # start over with an empty history.
            compress_instructions = prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0], focus=main_point[0])
            formatted_prompt = format_prompt(f"{compress_instructions}, {summary[0]}", history)
            history = []
            output = compress_history(formatted_prompt)
            summary[0] = output
            sum_out.append({"summary": summary[0]})
            json_obj = json.dumps(sum_out, indent=4)
            with open(f"{uid}-sum.json", "w", encoding="utf-8") as f:
                f.write(json_obj)
            _upload_json(f"{uid}-sum.json", f"book1/{filename}-summary.json")
            prompt = question_generate(output, history)
        main_point[0] = prompt
        full_conv.append((output, prompt))
        html_out = load_html(full_conv)
        yield prompt, history, summary[0], json_obj, json_hist, html_out




with gr.Blocks() as app:
    html = gr.HTML()

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    with gr.Row():
        submit_b = gr.Button()
        stop_b = gr.Button("Stop")
        clear = gr.ClearButton([msg, chatbot])
    # Passing "Summary" positionally made it the initial *value*; it is meant
    # as the label, like the JSON boxes below.
    sumbox = gr.Textbox(label="Summary", max_lines=100)
    with gr.Column():
        sum_out_box = gr.JSON(label="Summaries")
        hist_out_box = gr.JSON(label="History")

    # generate yields (msg, chatbot, summary, summaries, history, html) tuples.
    sub_b = submit_b.click(generate, [msg, chatbot], [msg, chatbot, sumbox, sum_out_box, hist_out_box, html])
    sub_e = msg.submit(generate, [msg, chatbot], [msg, chatbot, sumbox, sum_out_box, hist_out_box, html])
    # generate loops forever; Stop cancels the running event.
    stop_b.click(None, None, None, cancels=[sub_b, sub_e])

    # With inputs=None Gradio calls the load callback with no arguments, so
    # pass load_html an explicit empty conversation.
    app.load(lambda: load_html(None), None, html)
app.queue(default_concurrency_limit=20).launch()