import os
import subprocess
import sys
import spaces
import torch

import gradio as gr

from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from threading import Thread
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from typing import List
from PIL import Image

from transformers.image_transforms import resize, to_channel_dimension_format


# Install flash-attn at startup (best effort: the return code is not checked).
# FLASH_ATTENTION_SKIP_CUDA_BUILD is added on top of the current environment
# rather than replacing it, so PATH, HOME, pip configuration, proxies, etc.
# still reach the child process. `sys.executable -m pip` installs into the
# interpreter that is running this app, and an argument list avoids a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
# Number of `<image>` placeholder tokens put in the prompt: the resampler's
# latent count when a perceiver resampler is used, otherwise the number of
# patches in the (image_size // patch_size) x (image_size // patch_size) grid.
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token ids of the image marker strings; passed to `generate` as
# `bad_words_ids` so the model never emits them in its output.
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids


## Utils

def convert_to_rgb(image):
    """Return ``image`` in RGB mode, flattening any transparency onto white.

    A plain ``image.convert("RGB")`` gives a wrong background for transparent
    images. Non-RGB inputs are therefore first converted to RGBA and
    alpha-composited over a white (255, 255, 255) background, then converted
    to RGB. Images that are already RGB are returned unchanged (same object).
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image

# Replacement for the image processor's default transform: the processor is the
# same as the Idefics processor except for the interpolation (the original note
# refers to the BICUBIC interpolation inside siglip), so this hack redefines ONLY
# the transform step. NOTE(review): the resize below uses BILINEAR, not
# BICUBIC — confirm which interpolation is intended.
def custom_transform(x):
    """Turn one input image into a normalized, channels-first torch tensor.

    Steps: convert to RGB (``convert_to_rgb``), convert to a numpy array, resize
    to 960x960 with bilinear resampling, rescale by 1/255, normalize with the
    processor's ``image_mean``/``image_std``, move channels to the first axis,
    and wrap the result in ``torch.tensor``.

    Args:
        x: The image to transform (presumably a PIL image, since it goes
            through ``convert_to_rgb``).
    """
    image_processor = PROCESSOR.image_processor
    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    pixels = image_processor.rescale(pixels, scale=1 / 255)
    pixels = image_processor.normalize(
        pixels,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    channels_first = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
    return torch.tensor(channels_first)

## End of Utils


# Example screenshots shown in the "Templates Gallery". `os.listdir` returns
# entries in arbitrary order, so they are sorted to keep the gallery order
# stable across restarts and machines.
IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}"
    for ex_image in sorted(os.listdir("example_images"))
]


def install_playwright():
    """Best-effort ``playwright install`` so headless Chromium is available.

    The outcome is printed rather than raised. This function runs at module
    import time, so a failure here does not abort the import. Two failures are
    handled: the command exiting non-zero (``CalledProcessError``) and the
    command not being runnable at all, e.g. the ``playwright`` CLI missing from
    PATH (``FileNotFoundError``, an ``OSError``).
    """
    try:
        subprocess.run(["playwright", "install"], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error during Playwright installation: {e}")
    else:
        print("Playwright installation successful.")

install_playwright()


def add_file_gallery(
    selected_state: gr.SelectData,
    gallery_list: List[str]
):
    """Open the template the user clicked in the gallery as a PIL image.

    ``selected_state`` is not passed as an explicit event input (only the
    gallery is), so it is supplied via its ``gr.SelectData`` annotation; keep
    that annotation. ``selected_state.index`` picks the entry in
    ``gallery_list.root``, whose ``.image.path`` is opened.

    NOTE(review): despite the ``List[str]`` annotation, the code reads
    ``.root[...]`` and ``.image.path``, i.e. it expects Gradio's gallery data
    model rather than a plain list of strings — confirm against the installed
    Gradio version.
    """
    chosen_entry = gallery_list.root[selected_state.index]
    chosen_path = chosen_entry.image.path
    return Image.open(chosen_path)


def render_webpage(
    html_css_code,
):
    """Render markup in headless Chromium and return a full-page screenshot.

    The markup is loaded with ``page.set_content`` (no URL is navigated to).
    After waiting for the "networkidle" load state, a full-page PNG is written
    to ``DEFAULT_TEMP_DIR`` and then opened with ``PIL.Image.open``.

    Args:
        html_css_code: The markup to render.

    Returns:
        The screenshot as a PIL image.
    """
    # NOTE(review): the built-in `hash()` of a str is salted per process, so
    # this filename is only stable within a single run.
    screenshot_path = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
    desktop_user_agent = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
        " Safari/537.36"
    )

    with sync_playwright() as pw:
        browser = pw.chromium.launch(headless=True)
        context = browser.new_context(user_agent=desktop_user_agent)
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
        page.screenshot(path=screenshot_path, full_page=True)

        context.close()
        browser.close()

    return Image.open(screenshot_path)


@spaces.GPU(duration=300)
def model_inference(
    image,
):
    """Stream the HTML the model generates for a screenshot, then render it.

    Used as a Gradio generator handler. Each step yields
    ``(generated_text, rendered_image)``:

    - ``generated_text`` is all text decoded so far, with "</s>" removed.
    - ``rendered_image`` is ``None`` until generation ends, then a PIL
      screenshot of the generated page (from ``render_webpage``).

    Args:
        image: The screenshot to convert. It is presumably a PIL image, since
            the input ``gr.Image`` uses ``type="pil"``.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    # Prompt: BOS, then `image_seq_len` copies of `<image>` wrapped in
    # `<fake_token_around_image>` markers.
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform
    )
    inputs = {
        k: v.to(DEVICE)
        for k, v in inputs.items()
    }

    # skip_special_tokens is not requested, so the end-of-sequence string
    # "</s>" (assumed to be this tokenizer's EOS) appears in the streamed text.
    # It is used below to detect that generation has finished.
    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
    )
    generation_kwargs = dict(
        inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=4096,
        streamer=streamer,
    )
    # Generate in a background thread while this generator consumes the streamer.
    # NOTE(review): if MODEL.generate raises in the thread, the streamer is never
    # ended and the loop below blocks; consider a streamer `timeout`.
    thread = Thread(
        target=MODEL.generate,
        kwargs=generation_kwargs,
    )
    thread.start()

    generated_text = ""
    rendered_image = None
    for new_text in streamer:
        reached_eos = "</s>" in new_text
        # Append the chunk BEFORE rendering: a chunk may carry markup ahead of
        # "</s>", and that markup must be part of the rendered page.
        generated_text += new_text.replace("</s>", "")
        if reached_eos:
            rendered_image = render_webpage(generated_text)
        yield generated_text, rendered_image
    thread.join()

    # Generation ended without "</s>" (e.g. max_length reached): still render
    # whatever was produced, so the user gets a preview.
    if rendered_image is None:
        yield generated_text, render_webpage(generated_text)


# Output components are created outside the Blocks layout. They are placed
# later via `.render()`, which lets the Clear button (defined earlier in the
# layout) reference them.
generated_html = gr.Code(
    elem_id="generated_html",
    label="Extracted HTML",
)
rendered_html = gr.Image(
    show_share_button=False,
    show_download_button=False,
    label="Rendered HTML",
)

# Page styling: cap the app width, center h1 contents, and animate width/flex
# changes.
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""


with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4):
            rendered_html.render()

    with gr.Row():
        generated_html.render()

    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )

    # Run inference on upload, Submit, or Regenerate. Each trigger must be
    # listed exactly once. Registering a second `regenerate_btn.click` handler
    # would queue model_inference twice per click.
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    # Selecting a template loads it into the image box; inference is chained
    # with `.success` so it runs only if loading the template succeeded.
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    demo.load()

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)