import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
import shlex

import spaces


# Install the prebuilt diff-gaussian-rasterization wheel shipped in ./wheels
# (a cp310 / linux x86_64 build) every time this module is imported.
# The return code of pip is ignored, so a failed install does not stop start-up.
subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
# Disabled optional dependency install, kept for reference.
# subprocess.run(shlex.split("pip install xformers==0.0.23 --no-deps --index-url https://download.pytorch.org/whl/cu118"))

# Fetch the LGM checkpoint from the Hugging Face Hub at import time; the returned
# local path is later passed to lgm/infer.py via --resume (see optimize_stage_2).
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")

# Client-side JavaScript handed to gr.Blocks(js=...) below. If the page URL's
# ``__theme`` query parameter is not 'light', it sets it to 'light' and reloads
# the page, so the demo is always shown with the light theme.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

def check_img_input(control_image):
    """Stop the event chain early when no input image is present.

    Args:
        control_image: value of the input Image component; None when the user has
            neither uploaded an image nor picked one of the examples.

    Raises:
        gr.Error: if ``control_image`` is None.
    """
    has_image = control_image is not None
    if has_image:
        return
    raise gr.Error("Please select or upload an input image")

def check_video_input(image_block: Image.Image):
    """Stop the "Generate 4D" chain unless a stage-1 video exists for this image.

    The expected video path is ``tmp_data/<hash>_rgba_generated.mp4``, where
    ``<hash>`` is the SHA-256 of the image's raw pixel bytes. This is the same
    path that ``optimize_stage_1`` returns.

    Args:
        image_block: the current input image, or None when nothing is loaded.

    Raises:
        gr.Error: if no image is loaded, or no generated video exists for it yet.
    """
    # Without this guard an empty input crashes on ``None.tobytes()`` with an
    # AttributeError instead of showing the user a readable message.
    if image_block is None:
        raise gr.Error("Please select or upload an input image")
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        raise gr.Error("Please generate a video first")


@spaces.GPU(duration=120)
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Stage 1: generate a video of the object shown in the input image.

    All intermediate files live in ``tmp_data`` and are keyed on the SHA-256 of
    the image's raw pixel bytes, so later steps can locate them from the same image.

    Args:
        image_block: the input image.
        preprocess_chk: if True, save the image as ``<hash>.png`` and run
            ``scripts/process.py`` on it; otherwise save it directly as
            ``<hash>_rgba.png``, the input of the video step.
        seed_slider: random seed forwarded to ``scripts/gen_vid.py``.

    Returns:
        Path of the generated video, ``tmp_data/<hash>_rgba_generated.mp4``.

    Raises:
        gr.Error: if preprocessing or video generation did not produce its output file.
    """
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')

    if preprocess_chk:
        raw_path = os.path.join('tmp_data', f'{img_hash}.png')
        image_block.save(raw_path)
        # NOTE(review): process.py is assumed to write ``rgba_path`` (the video
        # step below reads it) — confirm against scripts/process.py.
        preprocess_cmd = ['python', 'scripts/process.py', raw_path]
        print(shlex.join(preprocess_cmd))
        subprocess.run(preprocess_cmd)
        if not os.path.exists(rgba_path):
            raise gr.Error("Image preprocessing failed, please try another image")
    else:
        image_block.save(rgba_path)

    # Run without a shell; the MKL settings go through the child's environment
    # instead of an inline ``export``.
    env = dict(os.environ, MKL_THREADING_LAYER='GNU', MKL_SERVICE_FORCE_INTEL='1')
    subprocess.run(
        [
            'python', 'scripts/gen_vid.py',
            '--path', rgba_path,
            # int() keeps a float slider value (e.g. 0.0) from becoming "--seed 0.0".
            '--seed', str(int(seed_slider)),
            '--bg', 'white',
        ],
        env=env,
    )

    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        raise gr.Error("Video generation failed, please try again")
    return video_path

@spaces.GPU(duration=180)
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Stage 2: build the 4D Gaussian model from the stage-1 RGBA image.

    First runs ``lgm/infer.py`` with the LGM checkpoint (``ckpt_path``) on
    ``tmp_data/<hash>_rgba.png``. Then runs ``main_4d.py`` with
    ``configs/4d_demo.yaml`` on the same image.

    Args:
        image_block: the input image; its SHA-256 pixel hash locates the stage-1 files.
        seed_slider: accepted for interface compatibility but currently unused; the
            output directory is not keyed on the seed.

    Returns:
        The 28 per-frame paths ``logs/<hash>_rgba_frames/000.ply`` … ``027.ply``.

    Raises:
        gr.Error: if any expected ``.ply`` frame was not produced.
    """
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')

    # NOTE(review): presumably produces the initial Gaussians consumed by
    # main_4d.py — confirm against lgm/infer.py.
    subprocess.run(['python', 'lgm/infer.py', 'big', '--resume', ckpt_path, '--test_path', rgba_path])

    # stage 2
    subprocess.run([
        'python', 'main_4d.py',
        '--config', os.path.join('configs', '4d_demo.yaml'),
        f'input={rgba_path}',
    ])

    image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
    ply_paths = [os.path.join(image_dir, f'{t:03d}.ply') for t in range(28)]
    missing = [path for path in ply_paths if not os.path.exists(path)]
    if missing:
        raise gr.Error(f"4D generation failed: {len(missing)} of {len(ply_paths)} frames are missing")
    return ply_paths


if __name__ == "__main__":
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''

    _DESCRIPTION = '''
    <div>
    <a style="display:inline-block" href="https://jiawei-ren.github.io/projects/dreamgaussian4d/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
    <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2312.17142"><img src="https://img.shields.io/badge/2312.17142-f9f7f7?logo="></a>
    <a style="display:inline-block; margin-left: .5em" href='https://github.com/jiawei-ren/dreamgaussian4d'><img src='https://img.shields.io/github/stars/jiawei-ren/dreamgaussian4d?style=social'/></a>
    </div>
    We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
    '''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**."

    # Example inputs: every .png in the 'data' folder next to this script, sorted by name.
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    examples_full = [
        os.path.join(example_folder, fn)
        for fn in sorted(os.listdir(example_folder))
        if fn.endswith('.png')
    ]

    # Compose demo layout & data flow. js_func forces the light theme on page load.
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-4D: inputs and controls on the left, generated outputs on the right.
        with gr.Row(variant='panel'):
            with gr.Column(scale=4):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
                gr.Markdown("random seed for video generation.")

                preprocess_chk = gr.Checkbox(
                    True,
                    label='Preprocess image automatically (remove background and recenter object)',
                )

                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40,
                )
                img_run_btn = gr.Button("Generate Video")
                fourd_run_btn = gr.Button("Generate 4D")
                gr.Markdown(_IMG_USER_GUIDE, visible=True)

            with gr.Column(scale=5):
                obj3d = gr.Video(label="video", height=290)  # stage-1 output video
                obj4d = Model4DGS(label="4D Model", height=500, fps=14)  # stage-2 .ply frames

            # Each button first runs a cheap, unqueued validation step; the GPU stage
            # is chained with .success so it runs only if validation did not raise.
            img_run_btn.click(
                check_img_input, inputs=[image_block], queue=False,
            ).success(
                optimize_stage_1,
                inputs=[image_block, preprocess_chk, seed_slider],
                outputs=[obj3d],
            )
            fourd_run_btn.click(
                check_video_input, inputs=[image_block], queue=False,
            ).success(
                optimize_stage_2,
                inputs=[image_block, seed_slider],
                outputs=[obj4d],
            )

    # Enable the request queue, capped at 10 pending events.
    demo.queue(max_size=10)
    demo.launch()