import spaces

import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)
from io import BytesIO
import requests
import os
from conversation import Conversation, SeparatorStyle

# Checkpoint identifier passed to the LLaVA loader below.
# NOTE(review): presumably a Hugging Face Hub repo id ("org/name") — confirm.
model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"

# Helper from llava.utils; must run before the model is loaded below.
# NOTE(review): presumably skips default torch weight initialisation to speed up
# loading, since weights come from the checkpoint — confirm in llava.utils.
disable_torch_init()
model_name = get_model_name_from_path(model_id)
# Loaded once at import time. tokenizer, model and image_processor are module-level
# globals read by infer_single_image; context_len is unused in this file.
# NOTE(review): the positional None is presumably llava's `model_base` (no separate
# base model) — confirm against load_pretrained_model's signature.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_id, None, model_name
)

def load_image(image_file, timeout=30):
    """Load an image from an HTTP(S) URL or a local path and convert it to RGB.

    Args:
        image_file: An ``http://`` / ``https://`` URL, or a filesystem path.
        timeout: Seconds to wait on the remote server before giving up
            (only used for URLs).

    Returns:
        A PIL image in RGB mode.

    Raises:
        FileNotFoundError: if ``image_file`` is neither a URL nor an existing path.
        requests.HTTPError: if downloading the URL returns an error status.
    """
    # Match full scheme prefixes, so a local file that merely starts with
    # "http" (e.g. "httpx.jpg") is not mistaken for a URL.
    if image_file.startswith(("http://", "https://")):
        # Without a timeout a stalled server would block this worker forever.
        response = requests.get(image_file, timeout=timeout)
        # Fail with the HTTP error instead of letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    if os.path.exists(image_file):
        # Close the underlying file handle once the RGB copy has been made.
        with Image.open(image_file) as img:
            return img.convert("RGB")
    raise FileNotFoundError(f"Görüntü dosyası {image_file} bulunamadı.")

def _with_image_token(prompt):
    """Return ``prompt`` with the model's image token inserted.

    Every ``IMAGE_PLACEHOLDER`` occurrence is replaced by the token. If the
    prompt has no placeholder, the token is prepended on its own line. When
    ``model.config.mm_use_im_start_end`` is set, the token is wrapped in the
    image start/end markers.
    """
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        # The placeholder is literal text, so plain string replacement (not a
        # regex substitution) is the exact operation wanted here.
        return prompt.replace(IMAGE_PLACEHOLDER, image_token)
    return image_token + "\n" + prompt


def _build_llama3_prompt(user_prompt):
    """Wrap ``user_prompt`` in the "llama3"-style conversation template.

    The assistant turn is left open (message ``None``) so the model generates it.
    """
    # System prompt (Turkish): "You are an AI assistant. The user will give you a
    # task. Your goal is to complete the task as faithfully as possible. Think
    # step by step while doing the task and justify your steps."
    conv = Conversation(
        system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
        roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        version="llama3",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|eot_id|>",
    )
    conv.append_message(conv.roles[0], user_prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def infer_single_image(model_id, image_file, prompt):
    """Run the loaded LLaVA model on one image and one text prompt.

    Args:
        model_id: Unused; kept for call compatibility. The module-level
            ``model``/``tokenizer``/``image_processor`` are always used.
        image_file: URL or local path, resolved by ``load_image``.
        prompt: User text, optionally containing ``IMAGE_PLACEHOLDER`` to mark
            where the image goes.

    Returns:
        The decoded generation (special tokens skipped, whitespace stripped),
        produced with greedy decoding and at most 512 new tokens.
    """
    full_prompt = _build_llama3_prompt(_with_image_token(prompt))
    print("full prompt: ", full_prompt)

    image = load_image(image_file)
    image_tensor = process_images(
        [image],
        image_processor,
        model.config,
    ).to(model.device, dtype=torch.float16)

    # Put the token ids on the model's device (the same as the image tensor)
    # rather than assuming CUDA.
    input_ids = (
        tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image.size],
            do_sample=False,
            max_new_tokens=512,
            use_cache=True,
        )

    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

# NOTE(review): spaces.GPU presumably allocates a GPU for each call on HF ZeroGPU Spaces — confirm.
@spaces.GPU
def bot_streaming(message, history):
    """Answer one multimodal chat turn with the Turkish LLaVA model.

    Args:
        message: Multimodal message with a ``"text"`` entry and a ``"files"``
            list. Each file is either a dict with a ``"path"`` key or a plain path.
        history: Earlier chat turns. A turn whose first element is a tuple is
            an uploaded file, with its path at index 0.

    Yields:
        The model's complete answer as a single string.

    Raises:
        gr.Error: if no image is attached to this message or found in history.
    """
    print(message)
    image = None
    if message["files"]:
        latest = message["files"][-1]
        image = latest["path"] if isinstance(latest, dict) else latest
    else:
        # Fall back to the most recently uploaded image from earlier turns.
        for turn in reversed(history):
            if isinstance(turn[0], tuple):
                image = turn[0][0]
                break

    if image is None:
        # gr.Error has to be *raised* for Gradio to show it; merely constructing
        # it let the request continue and crash inside inference.
        raise gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")

    result = infer_single_image(model_id, image, message["text"])
    print(result)
    yield result

# Widgets are created up front and handed to ChatInterface in place of its defaults.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Mesaj girin veya dosya yükleyin...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    # Multimodal chat UI backed by bot_streaming; the examples reference image
    # files expected next to this script.
    gr.ChatInterface(
        fn=bot_streaming,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        title="Cosmos LLaVA",
        description="",
        stop_btn="Stop Generation",
        examples=[
            {"text": "Bu kitabın adı ne?", "files": ["./book.jpg"]},
            {"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
            {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]},
        ],
    )

# Queue requests with the API closed, then start the server without a public
# share link and with the API page hidden.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)