#!/usr/bin/env python
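"""Gradio demo for BLIP-2: image captioning and multi-turn visual Q&A.

Loads a Salesforce BLIP-2 checkpoint (OPT or Flan-T5 variant) in 8-bit on GPU
and exposes captioning and chat tabs behind a Blocks UI.
"""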

import os
import string
import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration

# Style constants
CUSTOM_CSS = """
.container {
    max-width: 1000px;
    margin: auto;
    padding: 2rem;
    background: linear-gradient(to bottom right, #ffffff, #f8f9fa);
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}

.title {
    font-size: 2.5rem;
    color: #1a73e8;
    text-align: center;
    margin-bottom: 2rem;
    font-weight: bold;
}

.tab-nav {
    background: #f8f9fa;
    border-radius: 10px;
    padding: 0.5rem;
    margin-bottom: 1rem;
}

.input-box {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    transition: all 0.3s ease;
}

.input-box:focus-within {
    border-color: #1a73e8;
    box-shadow: 0 0 0 2px rgba(26, 115, 232, 0.2);
}

.button-primary {
    background: #1a73e8;
    color: white;
    padding: 0.75rem 1.5rem;
    border-radius: 8px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
}

.button-primary:hover {
    background: #1557b0;
    transform: translateY(-1px);
}

.button-secondary {
    background: #f8f9fa;
    color: #1a73e8;
    border: 1px solid #1a73e8;
    padding: 0.75rem 1.5rem;
    border-radius: 8px;
    cursor: pointer;
    transition: all 0.3s ease;
}

.button-secondary:hover {
    background: #e8f0fe;
}

.output-box {
    background: #ffffff;
    border-radius: 8px;
    padding: 1rem;
    margin-top: 1rem;
    border: 1px solid #e0e0e0;
}

.chatbot-message {
    padding: 1rem;
    margin: 0.5rem 0;
    border-radius: 8px;
    background: #f8f9fa;
}

.advanced-settings {
    background: #ffffff;
    border-radius: 8px;
    padding: 1rem;
    margin-top: 1rem;
}

.slider-container {
    padding: 0.5rem;
    background: #f8f9fa;
    border-radius: 6px;
}

.examples-container {
    margin-top: 2rem;
    padding: 1rem;
    background: #ffffff;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
"""

DESCRIPTION = """
<div class="title">
    🖼️ BLIP-2 Visual Intelligence System
</div>
<p style='text-align: center; color: #666;'>
    Advanced AI system for image understanding and natural conversation
</p>
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p style='color: #dc3545;'>Running on CPU 🥶 This demo requires a GPU to function properly.</p>"

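# Generation inputs are moved to this device; the quantized model itself is
# placed by device_map="auto" at load time.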
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MODEL_ID_OPT_2_7B = "Salesforce/blip2-opt-2.7b"
MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
MODEL_ID_FLAN_T5_XL = "Salesforce/blip2-flan-t5-xl"
MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
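# The served checkpoint can be overridden via the MODEL_ID environment variable.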
MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)

if MODEL_ID not in [MODEL_ID_OPT_2_7B, MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XL, MODEL_ID_FLAN_T5_XXL]:
    error_message = f"Invalid MODEL_ID: {MODEL_ID}"
    raise ValueError(error_message)

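# The model is only loaded when a GPU is present: 8-bit quantization via
# bitsandbytes keeps the larger checkpoints in memory but requires CUDA.
# On CPU the UI still renders, while generation calls will fail with a
# NameError (see the warning appended to DESCRIPTION above).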
if torch.cuda.is_available():
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Blip2ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True)
    )

@spaces.GPU
def generate_caption(
    image: PIL.Image.Image,
    decoding_method: str = "Nucleus sampling",
    temperature: float = 1.0,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.5,
    max_length: int = 50,
    min_length: int = 1,
    num_beams: int = 5,
    top_p: float = 0.9,
) -> str:
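    """Generate a caption for the image.

    With "Nucleus sampling", `do_sample` is enabled so `temperature` and
    `top_p` take effect; with "Beam search", generation is deterministic
    beam search over `num_beams` beams.
    """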
    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        pixel_values=inputs.pixel_values,
        do_sample=(decoding_method == "Nucleus sampling"),
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        top_p=top_p,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

@spaces.GPU
def answer_question(
    image: PIL.Image.Image,
    prompt: str,
    decoding_method: str = "Nucleus sampling",
    temperature: float = 1.0,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.5,
    max_length: int = 50,
    min_length: int = 1,
    num_beams: int = 5,
    top_p: float = 0.9,
) -> str:
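    """Answer a free-form text prompt about the image.

    The prompt is passed to the processor alongside the image, so it can
    carry the full "Question: ... Answer:" history built by `chat`.
    """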
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        do_sample=(decoding_method == "Nucleus sampling"),
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        top_p=top_p,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

def postprocess_output(output: str) -> str:
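    """Append a period when the output does not already end in punctuation."""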
    if output and output[-1] not in string.punctuation:
        output += "."
    return output

def chat(
    image: PIL.Image.Image,
    text: str,
    decoding_method: str = "Nucleus sampling",
    temperature: float = 1.0,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.5,
    max_length: int = 50,
    min_length: int = 1,
    num_beams: int = 5,
    top_p: float = 0.9,
    history_orig: list[str] | None = None,
    history_qa: list[str] | None = None,
) -> tuple[list[tuple[str, str]], list[str], list[str]]:
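    """Run one visual Q&A turn.

    `history_orig` holds the raw alternating user/model turns shown in the
    chatbot; `history_qa` holds the same turns wrapped in the
    "Question: ... Answer:" template, and its concatenation is used as the
    prompt so the model sees the earlier turns.
    """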
    history_orig = history_orig or []
    history_qa = history_qa or []
    history_orig.append(text)
    text_qa = f"Question: {text} Answer:"
    history_qa.append(text_qa)
    prompt = " ".join(history_qa)

    output = answer_question(
        image=image,
        prompt=prompt,
        decoding_method=decoding_method,
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        top_p=top_p,
    )
    output = postprocess_output(output)
    history_orig.append(output)
    history_qa.append(output)

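    # Pair alternating user/model turns into (user, bot) tuples for gr.Chatbot.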
    chat_val = list(zip(history_orig[0::2], history_orig[1::2], strict=False))
    return chat_val, history_orig, history_qa

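# Presumably marks `chat` as ZeroGPU-aware for the Spaces runtime: it acquires
# GPU time through the @spaces.GPU-decorated `answer_question` it calls rather
# than through a decorator of its own.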
chat.zerogpu = True  # type: ignore

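# (image path, question) pairs surfaced in the UI through gr.Examples below.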
examples = [
    [
        "images/house.png",
        "How could someone get out of the house?",
    ],
    [
        "images/flower.jpg",
        "What is this flower and where is it's origin?",
    ],
    [
        "images/pizza.jpg",
        "What are steps to cook it?",
    ],
    [
        "images/sunset.jpg",
        "Here is a romantic message going along the photo:",
    ],
    [
        "images/forbidden_city.webp",
        "In what dynasties was this place built?",
    ],
]

with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(DESCRIPTION)
    
    with gr.Group(elem_classes="container"):
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(
                    type="pil",
                    label="Upload Image",
                    elem_classes="input-box"
                )
            
            with gr.Column(scale=2):
                with gr.Tabs(elem_classes="tab-nav"):
                    with gr.Tab(label="✨ Image Captioning"):
                        caption_button = gr.Button(
                            "Generate Caption",
                            elem_classes="button-primary"
                        )
                        caption_output = gr.Textbox(
                            label="Generated Caption",
                            elem_classes="output-box"
                        )
                        
                    with gr.Tab(label="💭 Visual Q&A"):
                        chatbot = gr.Chatbot(
                            elem_classes="chatbot-message"
                        )
                        history_orig = gr.State(value=[])
                        history_qa = gr.State(value=[])
                        vqa_input = gr.Textbox(
                            placeholder="Ask me anything about the image...",
                            elem_classes="input-box"
                        )
                        
                        with gr.Row():
                            clear_button = gr.Button(
                                "Clear Chat",
                                elem_classes="button-secondary"
                            )
                            submit_button = gr.Button(
                                "Send Message",
                                elem_classes="button-primary"
                            )

        with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
            with gr.Row():
                with gr.Column():
                    text_decoding_method = gr.Radio(
                        choices=["Beam search", "Nucleus sampling"],
                        value="Nucleus sampling",
                        label="Decoding Method"
                    )
                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        value=1.0,
                        label="Temperature",
                        info="Used with nucleus sampling",
                        elem_classes="slider-container"
                    )
                    length_penalty = gr.Slider(
                        minimum=-1.0,
                        maximum=2.0,
                        value=1.0,
                        label="Length Penalty",
                        info="Set to larger for longer sequence",
                        elem_classes="slider-container"
                    )
                with gr.Column():
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=5.0,
                        value=1.5,
                        label="Repetition Penalty",
                        info="Larger value prevents repetition",
                        elem_classes="slider-container"
                    )
                    max_length = gr.Slider(
                        minimum=20,
                        maximum=512,
                        value=50,
                        label="Max Length",
                        elem_classes="slider-container"
                    )
                    min_length = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=1,
                        label="Min Length",
                        elem_classes="slider-container"
                    )
                    num_beams = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        label="Number of Beams",
                        elem_classes="slider-container"
                    )
                    top_p = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        value=0.9,
                        label="Top P",
                        info="Used with nucleus sampling",
                        elem_classes="slider-container"
                    )

    with gr.Group(elem_classes="examples-container"):
        gr.Examples(
            examples=examples,
            inputs=[image, vqa_input],
            label="Try these examples"
        )

    # Event handlers
    caption_button.click(
        fn=generate_caption,
        inputs=[
            image,
            text_decoding_method,
            temperature,
            length_penalty,
            repetition_penalty,
            max_length,
            min_length,
            num_beams,
            top_p,
        ],
        outputs=caption_output,
        api_name="caption",
    )

    chat_inputs = [
        image,
        vqa_input,
        text_decoding_method,
        temperature,
        length_penalty,
        repetition_penalty,
        max_length,
        min_length,
        num_beams,
        top_p,
        history_orig,
        history_qa,
    ]
    chat_outputs = [
        chatbot,
        history_orig,
        history_qa,
    ]

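    # Submitting a question runs one chat turn, then clears the input box on success.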
    vqa_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_outputs
    ).success(
        fn=lambda: "",
        outputs=vqa_input,
        queue=False,
        api_name=False
    )

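    # "Clear Chat" resets the input box, the chat display, and both histories.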
    clear_button.click(
        fn=lambda: ("", [], [], []),
        inputs=None,
        outputs=[
            vqa_input,
            chatbot,
            history_orig,
            history_qa,
        ],
        queue=False,
        api_name="clear"
    )

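    # Swapping the image invalidates the caption and the conversation state.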
    image.change(
        fn=lambda: ("", [], [], []),
        inputs=None,
        outputs=[
            caption_output,
            chatbot,
            history_orig,
            history_qa,
        ],
        queue=False
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()