import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import TextIteratorStreamer
from datasets import load_dataset
import spaces
import pandas as pd

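# Load REKA's Vibe-Eval benchmark and keep only the columns the UI needs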
rekaeval = "RekaAI/VibeEval"
dataset = load_dataset(rekaeval, split="test")
df = pd.DataFrame(dataset)
df = df[['media_url', 'prompt', 'reference']]
df_markdown = df[['media_url', 'prompt']].copy()

# Function to convert URL to HTML img tag
def mediaurl_to_img_tag(url):
    return f'<img src="{url}">'

# Apply the function to the DataFrame column
df_markdown['media_url'] = df_markdown['media_url'].apply(mediaurl_to_img_tag)


PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://avatars.githubusercontent.com/u/51063788?s=400&u=479ecc9d93d8a373b5c2e69ebe846f394811e94a&v=4" style="width: 40%; opacity: 0.30;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-Llama3-8B With REKA Vibe-Eval</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Test your vision LLMs with REKA's new Vibe-Eval benchmark</p>
</div>
"""
title = "Testing LLaVA-Llama3-8B with REKA's Vibe-Eval"
description = "Evaluate <a href='https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers'>LLaVA-Llama3-8B</a> on <b>REKA's Vibe-Eval</b>. Click on a row in the eval dataset and start chatting about it."

CSS ="""
.contain { display: flex !important; flex-direction: column !important; }
#component-0 { height: 100% !important; }
#chatbot { flex-grow: 1 !important; }
"""

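# Load LLaVA-Llama3-8B in half precision; low_cpu_mem_usage avoids materialising
# the full fp32 weights in RAM while loading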
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to("cuda:0")
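# 128009 is Llama-3's <|eot_id|> token; using it as EOS stops generation at the end of a turn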
model.generation_config.eos_token_id = 128009


@spaces.GPU
def bot_streaming(message, history):
    image = None
    if message["files"]:
        # The last file is either a dict with a "path" key or a plain file path
        last_file = message["files"][-1]
        image = last_file["path"] if isinstance(last_file, dict) else last_file
    else:
        # No image uploaded this turn: fall back to the most recent image from a
        # past turn, which Gradio stores as a tuple in the user slot of history
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if image is None:
        # gr.Error must be raised, not merely instantiated, to surface in the UI
        raise gr.Error("You need to upload an image for LLaVA to work.")

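    # Llama-3 chat format: the processor expands the <image> placeholder into
    # image tokens, and <|eot_id|> closes the user turn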
    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    image = Image.open(image)
    # Keyword args guard against the text/images positional-order change across transformers versions
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    time.sleep(0.5)  # give generate() a head start before reading the stream
    for new_text in streamer:
        # Strip the end-of-turn token so it never reaches the UI
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        buffer += new_text
        time.sleep(0.06)  # small delay smooths the token-by-token animation
        yield buffer


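# Chat widgets shared with the ChatInterface below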
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, elem_id="chatbot")
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False, scale=1)

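# Two-column layout: chat on the left, a browsable Vibe-Eval table on the right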
with gr.Blocks(fill_height=True, css=CSS) as demo:
  gr.HTML(f'<h1><center>{title}</center></h1>')
  gr.HTML(f'<center>{description}</center>') 
  with gr.Row(equal_height=True):
    with gr.Column():
      gr.ChatInterface(
            fn=bot_streaming,
            stop_btn="Stop Generation",
            multimodal=True,
            textbox=chat_input,
            chatbot=chatbot,
        )
    with gr.Column():
      with gr.Accordion('Open to see the ground-truth reference:', open=False):
          reference = gr.Markdown()
      with gr.Row():
        b1 = gr.Button("Previous", interactive=False)
        b2 = gr.Button("Next")
      reka = gr.Dataframe(value=df_markdown[0:5], label='Reka-Vibe-Eval', datatype=['markdown', 'str'], wrap=False, interactive=False, height=700)
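      # Hidden counters tracking the first/last dataset row currently shown in the table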
      num_start = gr.Number(visible=False, value=0)
      num_end = gr.Number(visible=False, value=4)
    
  def get_example(reka, start, evt: gr.SelectData):
      # Map the clicked row back to the full dataframe (the table shows a 5-row window)
      x = evt.index[0] + int(start)
      image = df.iloc[x, 0]
      prompt = df.iloc[x, 1]
      reference = df.iloc[x, 2]
      # Pre-fill the chat input with the example's image and prompt
      example = {"text": prompt, "files": [image]}
      return example, reference

  def display_next(dataframe, end):
    # Advance the 5-row window (.loc slicing is inclusive on both ends)
    start = int(end) + 1
    end = start + 4
    df_images = df_markdown.loc[start:end]
    # Re-enable "Previous" now that we are past the first page
    return df_images, end, start, gr.Button(interactive=True)

  def display_previous(dataframe, start):
    # Step the window back; the previous page ends just above the current start
    # (the original end = start, start = end - 5 showed six rows instead of five)
    end = int(start) - 1
    start = max(end - 4, 0)
    df_images = df_markdown.loc[start:end]
    # Disable "Previous" again once we are back on the first page
    return df_images, end, start, gr.Button(interactive=(start != 0))

  reka.select(get_example, [reka, num_start], [chat_input, reference], show_progress="hidden")
  b2.click(fn=display_next, inputs=[reka, num_end], outputs=[reka, num_end, num_start, b1], api_name="next_rows", show_progress=False)
  b1.click(fn=display_previous, inputs=[reka, num_start], outputs=[reka, num_end, num_start, b1], api_name="previous_rows")


demo.queue()
demo.launch(debug=True)