import gradio as gr
#import urllib.request
import requests
import bs4
import lxml
import os
#import subprocess
from huggingface_hub import InferenceClient,HfApi
import random
import json
import datetime
from pypdf import PdfReader
import uuid
#from query import tasks
from agent import (
    PREFIX,
    COMPRESS_DATA_PROMPT,
    COMPRESS_DATA_PROMPT_SMALL,
    LOG_PROMPT,
    LOG_RESPONSE,
)
api=HfApi()

client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

def find_all(url):
    return_list=[]
    print (url)
    #if action_input in query.tasks:
    print (f"trying URL:: {url}")        
    try:
        if url != "" and url != None:    
            out = []
            source = requests.get(url)
            #source = urllib.request.urlopen(url).read()
            soup = bs4.BeautifulSoup(source.content,'lxml')
            # title of the page
            print(soup.title)
            # get attributes:
            print(soup.title.name)
            # get values:
            print(soup.title.string)
            # beginning navigation:
            print(soup.title.parent.name)
            #rawp.append([tag.name for tag in soup.find_all()] )
            print([tag.name for tag in soup.find_all()])
            #rawp=(f'RAW TEXT RETURNED: {soup.text}')
            rawp=(f'RAW HTML RETURNED: {soup}')
            out.append(rawp)
            q=("a","p","span","content","article")
            for p in soup.find_all(q):
                out.append([{q:p.string,"parent":p.parent.name,"previous":[p.previous],"first-child":[b.name for b in p.children],"content":p}])
            #print (f'OUT :: {out}')
            '''
            c=0
            out = str(out)
            rl = len(out)
            print(f'rl:: {rl}')
            #for ea in out:
            for i in str(out):
                if i == " " or i=="," or i=="\n":
                    c +=1
            print (f'c:: {c}')
            if rl > MAX_DATA:
                print("compressing...")
                rawp = compress_data(c,purpose,task,out)    
            print (rawp)
            print (f'out:: {out}')
            '''
            print(rawp)
            return True, rawp
        else: 
            return False, "Enter Valid URL"
    except Exception as e:
        print (e)
        return False, f'Error: {e}'

        #else:
    #    history = "observation: The search query I used did not return a valid response"
        
    return "MAIN", None, history, task

def read_txt(txt_path):
    text=""
    with open(txt_path,"r") as f:
        text = f.read()
    f.close()
    print (text)
    return text

def read_pdf(pdf_path):
    text=""
    reader = PdfReader(f'{pdf_path}')
    number_of_pages = len(reader.pages)
    for i in range(number_of_pages):
        page = reader.pages[i]
        text = f'{text}\n{page.extract_text()}'
    print (text)
    return text

error_box=[]
def read_pdf_online(url):
    uid=uuid.uuid4()
    print(f"reading {url}")
    response = requests.get(url, stream=True)
    print(response.status_code)
    text=""
#################
    
#####################
    try:
        if response.status_code == 200:
            with open("test.pdf", "wb") as f:
                f.write(response.content)
            #f.close()
            #out = Path("./data.pdf")
            #print (out)
            reader = PdfReader("test.pdf")
            number_of_pages = len(reader.pages)
            print(number_of_pages)
            for i in range(number_of_pages):
                page = reader.pages[i]
                text = f'{text}\n{page.extract_text()}'
                print(f"PDF_TEXT:: {text}")
            return text
        else:
            text = response.status_code
            error_box.append(url)
            print(text)
            return text


    except Exception as e:
        print (e)
        return e


VERBOSE = True
MAX_HISTORY = 100
MAX_DATA = 20000

def format_prompt(message, history):
  prompt = "<s>"
  for user_prompt, bot_response in history:
    prompt += f"[INST] {user_prompt} [/INST]"
    prompt += f" {bot_response}</s> "
  prompt += f"[INST] {message} [/INST]"
  return prompt



def run_gpt(
    prompt_template,
    stop_tokens,
    max_tokens,
    seed,
    **prompt_kwargs,
):
    print(seed)
    timestamp=datetime.datetime.now()
    
    generate_kwargs = dict(
        temperature=0.9,
        max_new_tokens=max_tokens,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=seed,
    )
    
    content = PREFIX.format(
        timestamp=timestamp,
        purpose="Compile the provided data and complete the users task"
    ) + prompt_template.format(**prompt_kwargs)
    if VERBOSE:
        print(LOG_PROMPT.format(content))
    
    
    #formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    #formatted_prompt = format_prompt(f'{content}', history)

    stream = client.text_generation(content, **generate_kwargs, stream=True, details=True, return_full_text=False)
    resp = ""
    for response in stream:
        resp += response.token.text
        #yield resp

    if VERBOSE:
        print(LOG_RESPONSE.format(resp))
    return resp

    
def compress_data(c, instruct, history):
    seed=random.randint(1,1000000000)
    
    print (c)
    #tot=len(purpose)
    #print(tot)
    divr=int(c)/MAX_DATA
    divi=int(divr)+1 if divr != int(divr) else int(divr)
    chunk = int(int(c)/divr)
    print(f'chunk:: {chunk}')
    print(f'divr:: {divr}')
    print (f'divi:: {divi}')
    out = []
    #out=""
    s=0
    e=chunk
    print(f'e:: {e}')
    new_history=""
    #task = f'Compile this data to fulfill the task: {task}, and complete the purpose: {purpose}\n'
    for z in range(divi):
        print(f's:e :: {s}:{e}')
        
        hist = history[s:e]
        
        resp = run_gpt(
            COMPRESS_DATA_PROMPT_SMALL,
            stop_tokens=["observation:", "task:", "action:", "thought:"],
            max_tokens=8192,
            seed=seed,
            direction=instruct,
            knowledge="",
            history=hist,
        )
        out.append(resp)
        new_history = resp
        print (resp)
        out+=resp
        e=e+chunk
        s=s+chunk
    return out

    
def compress_data_og(c, instruct, history):
    seed=random.randint(1,1000000000)
    
    print (c)
    #tot=len(purpose)
    #print(tot)
    divr=int(c)/MAX_DATA
    divi=int(divr)+1 if divr != int(divr) else int(divr)
    chunk = int(int(c)/divr)
    print(f'chunk:: {chunk}')
    print(f'divr:: {divr}')
    print (f'divi:: {divi}')
    out = []
    #out=""
    s=0
    e=chunk
    print(f'e:: {e}')
    new_history=""
    #task = f'Compile this data to fulfill the task: {task}, and complete the purpose: {purpose}\n'
    for z in range(divi):
        print(f's:e :: {s}:{e}')
        
        hist = history[s:e]
        
        resp = run_gpt(
            COMPRESS_DATA_PROMPT_SMALL,
            stop_tokens=["observation:", "task:", "action:", "thought:"],
            max_tokens=8192,
            seed=seed,
            direction=instruct,
            knowledge=new_history,
            history=hist,
        )
        
        new_history = resp
        print (resp)
        out+=resp
        e=e+chunk
        s=s+chunk
    
    resp = run_gpt(
        COMPRESS_DATA_PROMPT,
        stop_tokens=["observation:", "task:", "action:", "thought:"],
        max_tokens=8192,
        seed=seed,
        direction=instruct,
        knowledge=new_history,
        history="All data has been recieved.",
    )
    print ("final" + resp)
    #history = "observation: {}\n".format(resp)
    return resp



def summarize(inp,history,data=None,files=None,url=None,pdf_url=None,pdf_batch=None):
    json_box=[]
    if inp == "":
        inp = "Process this data"
    history.clear()
    history = [(inp,"Working on it...")] 
    yield "",history,error_box,json_box

    if pdf_batch.startswith("http"):
        c=0
        data=""
        for i in str(pdf_batch):
            if i==",":
                c+=1
        print (f'c:: {c}')

        try:
            for i in range(c+1):
                batch_url = pdf_batch.split(",",c)[i]
                bb = read_pdf_online(batch_url)
                data=f'{data}\nFile Name URL ({batch_url}):\n{bb}'
        except Exception as e:
            print(e)
            #data=f'{data}\nError reading URL ({batch_url})'
            

        
    if pdf_url.startswith("http"):
        print("PDF_URL")
        out = read_pdf_online(pdf_url)
        data=out
    if url.startswith("http"):
        val, out = find_all(url)
        if not val:
            data="Error"
            rawp = str(out)
        else:
            data=out
    if files:
        for i, file in enumerate(files):
            try: 
                print (file)
                if file.endswith(".pdf"):
                    zz=read_pdf(file)
                    print (zz)
                    data=f'{data}\nFile Name ({file}):\n{zz}'
                elif file.endswith(".txt"):
                    zz=read_txt(file)
                    print (zz)
                    data=f'{data}\nFile Name ({file}):\n{zz}'                
            except Exception as e:
                data=f'{data}\nError opening File Name ({file})'                
                print (e) 
    if data != "Error" and data != "":
        print(inp)
        out = str(data)
        rl = len(out)
        print(f'rl:: {rl}')
        c=1
        for i in str(out):
            if i == " " or i=="," or i=="\n":
                c +=1
        print (f'c:: {c}')
        
        json_out = compress_data(c,inp,out)  
        #json_box.append(json_out)

        json_object = json.dumps(json_out, indent=4)
        json_box.append(json_object)
        # Writing to sample.json
        #with open("tmp.json", "w") as outfile:
        #    outfile.write(json_object)
        #outfile.close()
        
        #json_box.append(json_out)
        out = str(json_out)
        rl = len(out)
        print(f'rl:: {rl}')
        c=1
        for i in str(out):
            if i == " " or i=="," or i=="\n":
                c +=1
        print (f'c2:: {c}')
        rawp = compress_data(c,inp,out)
    else:
        rawp = "Provide a valid data source"
    #print (rawp)
    #print (f'out:: {out}')
    #history += "observation: the search results are:\n {}\n".format(out)
    #task = "complete?"
    history.clear()
    history.append((inp,rawp))
    yield "", history,error_box,json_box

#################################
def clear_fn():
    return "",[(None,None)]

with gr.Blocks() as app:
    gr.HTML("""<center><h1>Mixtral 8x7B TLDR Summarizer + Web</h1><h3>Summarize Data of unlimited length</h3>""")
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=3):
            prompt=gr.Textbox(label = "Instructions (optional)")
        with gr.Column(scale=1):
            button=gr.Button()
        
        #models_dd=gr.Dropdown(choices=[m for m in return_list],interactive=True)
    with gr.Row():
        stop_button=gr.Button("Stop")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        with gr.Tab("Text"):
            data=gr.Textbox(label="Input Data (paste text)", lines=6)
        with gr.Tab("File"):
            file=gr.Files(label="Input File (.pdf .txt)")
        with gr.Tab("Raw HTML"):
            url = gr.Textbox(label="URL")
        with gr.Tab("PDF URL"):
            pdf_url = gr.Textbox(label="PDF URL")       
        with gr.Tab("PDF Batch"):
            pdf_batch = gr.Textbox(label="PDF Batch (comma separated)")
    json_out=gr.JSON()
    e_box=gr.Textbox()
    #text=gr.JSON()
    #inp_query.change(search_models,inp_query,models_dd)
    clear_btn.click(clear_fn,None,[prompt,chatbot])
    go=button.click(summarize,[prompt,chatbot,data,file,url,pdf_url,pdf_batch],[prompt,chatbot,e_box,json_out])
    stop_button.click(None,None,None,cancels=[go])
app.launch(server_port=7860,show_api=False)