import random
import re

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForVision2Seq

# Single checkpoint id shared by the model weights and the matching processor.
KOSMOS2_CHECKPOINT = "microsoft/kosmos-2-patch14-224"

# Loaded once at import time; both objects are module-level globals.
model = AutoModelForVision2Seq.from_pretrained(KOSMOS2_CHECKPOINT)
processor = AutoProcessor.from_pretrained(KOSMOS2_CHECKPOINT)


def draw_bounding_boxes(image: Image.Image, entities):
    """Draw each entity's bounding boxes and label onto ``image``.

    The image is modified in place and also returned, so callers that need
    the original untouched should pass a copy.

    Args:
        image: PIL image to annotate.
        entities: Iterable of ``(label, span, boxes)`` triples. The middle
            element is ignored here. Each box is an ``(x0, y0, x1, y1)``
            sequence that is scaled by the image width/height before
            drawing, i.e. coordinates are treated as fractions of the
            image size.

    Returns:
        The same ``image`` object with rectangles and labels drawn on it.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size

    color_bank = ("#0AC2FF", "#30D5C8", "#F3C300", "#47FF0A", "#C2FF0A")

    font_size = 20
    label_margin = 5
    try:
        font = ImageFont.truetype("assets/arial.ttf", font_size)
    except OSError:
        # Font file missing or unreadable: fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    for label, _, boxes in entities:
        for box in boxes:
            left, top = box[0] * width, box[1] * height
            right, bottom = box[2] * width, box[3] * height

            # Outline and label colors are picked independently for each box.
            outline_color = random.choice(color_bank)
            text_fill_color = random.choice(color_bank)

            draw.rectangle([left, top, right, bottom], outline=outline_color, width=4)

            # Prefer placing the label just above the box. When the box sits
            # too close to the top edge that position falls outside the
            # image, so draw the label just inside the box instead.
            label_y = top - font_size - label_margin
            if label_y < 0:
                label_y = max(top, 0) + label_margin
            draw.text(
                (left + label_margin, label_y),
                label,
                fill=text_fill_color,
                font=font,
            )

    return image

def highlight_entities(text, entities):
    """Return ``text`` with every occurrence of an entity label wrapped in asterisks.

    All labels are substituted in a single pass, so a label that is listed
    more than once, or that is contained in another label, is wrapped
    exactly once instead of accumulating nested asterisks. Empty labels are
    ignored.

    Args:
        text: The text to annotate.
        entities: Iterable of sequences whose first element is the label
            string; any further elements are ignored.

    Returns:
        The annotated text, or ``text`` unchanged when there is nothing to
        highlight.
    """
    labels = {entity[0] for entity in entities if entity[0]}
    if not labels:
        return text

    # Longest labels first so that "a car" wins over "car" at the same
    # position; the secondary key only makes the pattern deterministic.
    ordered = sorted(labels, key=lambda label: (-len(label), label))
    pattern = "|".join(re.escape(label) for label in ordered)
    return re.sub(pattern, lambda match: f"*{match.group(0)}*", text)

def process_image(image, prompt_option, custom_prompt):
    """Run the Kosmos-2 model on an image and return the annotated results.

    Args:
        image: A PIL image, or anything ``Image.open`` accepts.
        prompt_option: ``"Brief"`` or ``"Detailed"`` select a built-in
            grounding prompt; any other value selects ``custom_prompt``.
        custom_prompt: Prompt passed to the processor as-is when no built-in
            prompt is selected. ``None`` is treated as an empty prompt.

    Returns:
        A 4-tuple ``(processed_image, processed_text, entities,
        highlighted_text)``: a copy of the input with bounding boxes drawn,
        the text and entities returned by
        ``processor.post_process_generation``, and the text with entity
        labels wrapped in asterisks.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Without this guard a missing upload reaches Image.open(None) and
        # fails with an error that says nothing useful to the user.
        raise gr.Error("Please upload an image before running the model.")
    if not isinstance(image, Image.Image):
        image = Image.open(image)

    # Built-in prompts; anything else falls back to the user's own prompt.
    preset_prompts = {
        "Brief": "<grounding>An image of",
        "Detailed": "<grounding> Describe this image in detail:",
    }
    prompt = preset_prompts.get(prompt_option, custom_prompt or "")

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    processed_text, entities = processor.post_process_generation(generated_text)

    # Draw on a copy so the caller's image is left untouched.
    processed_image = draw_bounding_boxes(image.copy(), entities)

    highlighted_entities = highlight_entities(processed_text, entities)

    return processed_image, processed_text, entities, highlighted_entities

def clear_interface():
    """Return four ``None`` values, one for each component being reset."""
    return (None,) * 4


# Build the Gradio interface: image input/output, prompt options, action
# buttons, text/entity results, examples, and the event wiring.
with gr.Blocks(gr.themes.Soft()) as demo:
    gr.Markdown("# Kosmos-2 VQA Demo")
    gr.Markdown(
        "Run this space on your own hardware with this command: "
        "```docker run -it -p 7860:7860 --platform=linux/amd64 "
        "\tregistry.hf.space/macadeliccc-kosmos-2-demo:latest python app.py```"
    )

    # Uploaded image on the left, annotated result on the right.
    with gr.Row(equal_height=True):
        image_input = gr.Image(type="pil", label="Upload Image")
        processed_image_output = gr.Image(label="Processed Image")

    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Accordion("Prompt Options"):
                prompt_option = gr.Radio(
                    choices=["Brief", "Detailed", "Custom"],
                    label="Select Prompt Option",
                    value="Brief",
                )
                custom_prompt_input = gr.Textbox(label="Custom Prompt", visible=False)

                def show_custom_prompt_input(selected_option):
                    """Show the custom prompt box only while "Custom" is selected."""
                    # A bare bool returned to a Textbox output would be
                    # written into the box as its value; an update object
                    # changes the `visible` property instead.
                    return gr.update(visible=selected_option == "Custom")

                prompt_option.change(
                    show_custom_prompt_input,
                    inputs=[prompt_option],
                    outputs=[custom_prompt_input],
                )

    with gr.Row(equal_height=True):
        submit_button = gr.Button("Run Model")
        clear_button = gr.Button("Clear", elem_id="clear_button")

    with gr.Row(equal_height=True):
        with gr.Column():
            highlighted_entities = gr.Textbox(label="Processed Text")
        with gr.Column():
            with gr.Accordion("Entities"):
                entities_output = gr.JSON(label="Entities", elem_id="entities_output")

    # Define examples
    examples = [
        ["assets/snowman.jpg", "Custom", "<grounding> Question: Where is<phrase> the fire</phrase><object><patch_index_0005><patch_index_0911></object> next to? Answer:"],
        ["assets/traffic.jpg", "Detailed", "<grounding> Describe this image in detail:"],
        ["assets/umbrellas.jpg", "Brief", "<grounding>An image of"],
    ]
    gr.Examples(examples, inputs=[image_input, prompt_option, custom_prompt_input])

    with gr.Row(equal_height=True):
        with gr.Accordion("Additional Info"):
            gr.Markdown(
                "This demo uses the "
                "[Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) model."
            )

    def run_model(image, selected_option, custom_prompt):
        """Map the four results of ``process_image`` onto the three outputs."""
        # process_image returns (image, plain text, entities, highlighted
        # text). Only three components are wired up, so choose explicitly:
        # the text box shows the highlighted text rather than having the
        # fourth value dropped.
        processed_image, _plain_text, entities, highlighted_text = process_image(
            image, selected_option, custom_prompt
        )
        return processed_image, highlighted_text, entities

    submit_button.click(
        fn=run_model,
        inputs=[image_input, prompt_option, custom_prompt_input],
        outputs=[processed_image_output, highlighted_entities, entities_output],
    )

    clear_button.click(
        fn=clear_interface,
        inputs=[],
        outputs=[image_input, processed_image_output, highlighted_entities, entities_output],
    )



# Launch the `demo` Blocks app built above; this runs whenever the module is executed.
demo.launch()