# Copyright (c) 2023-2024, Qi Zuo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from PIL import Image
import numpy as np
import gradio as gr
import base64
import spaces
import subprocess
import os

# def install_cuda_toolkit():
# #     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
# #     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
# #     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
# #     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
# #     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
# #     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

#     os.environ["CUDA_HOME"] = "/usr/local/cuda"
#     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
#     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
#         os.environ["CUDA_HOME"],
#         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
#     )
#     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
#     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

# install_cuda_toolkit()

def launch_pretrained():
    """Download and unpack the LHM runtime archives into the current directory.

    Each archive is fetched from the ``DyrusQZ/LHM_Runtime`` Hugging Face
    model repo into ``./``, extracted with ``tar -xvf`` and then deleted.

    Raises:
        RuntimeError: if extracting (or removing) an archive fails.
    """
    from huggingface_hub import hf_hub_download

    for archive in ("assets.tar", "LHM-0.5B.tar", "LHM_prior_model.tar"):
        hf_hub_download(repo_id="DyrusQZ/LHM_Runtime", repo_type='model', filename=archive, local_dir="./")
        # os.system only reports failure through its return value; check it so
        # a broken extraction stops here instead of surfacing later as missing
        # assets/weights.
        exit_status = os.system(f"tar -xvf {archive} && rm {archive}")
        if exit_status != 0:
            raise RuntimeError(f"Failed to extract {archive} (exit status {exit_status})")

def launch_env_not_compile_with_cuda():
    """Install the runtime Python packages that need no local CUDA compilation.

    The commands run through the shell one after another. Their exit statuses
    are ignored, so a failing command does not stop the remaining ones.
    """
    pip_commands = (
        "pip install chumpy",
        "pip uninstall -y basicsr",
        "pip install git+https://github.com/hitsz-zuoqi/BasicSR/",
        # Disabled alternatives kept for reference:
        #   pip install -e ./third_party/sam2
        #   pip install git+https://github.com/hitsz-zuoqi/sam2/
        #   pip install git+https://github.com/ashawkey/diff-gaussian-rasterization/
        #   pip install git+https://github.com/camenduru/simple-knn/
        "pip install numpy==1.23.0",
        "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html",
    )
    for command in pip_commands:
        os.system(command)

# def launch_env_compile_with_cuda():
#     # simple_knn
#     os.system("wget oss://virutalbuy-public/share/aigc3d/data/for_lingteng/LHM/simple_knn.zip && wget oss://virutalbuy-public/share/aigc3d/data/for_lingteng/LHM/simple_knn-0.0.0.dist-info.zip")
#     os.system("unzip simple_knn.zip && unzip simple_knn-0.0.0.dist-info.zip")
#     os.system("mv simple_knn /usr/local/lib/python3.10/site-packages/")
#     os.system("mv simple_knn-0.0.0.dist-info /usr/local/lib/python3.10/site-packages/")

#     # diff_gaussian
#     os.system("wget oss://virutalbuy-public/share/aigc3d/data/for_lingteng/LHM/diff_gaussian_rasterization.zip && wget oss://virutalbuy-public/share/aigc3d/data/for_lingteng/LHM/diff_gaussian_rasterization-0.0.0.dist-info.zip")
#     os.system("unzip diff_gaussian_rasterization.zip && unzip diff_gaussian_rasterization-0.0.0.dist-info.zip")
#     os.system("mv diff_gaussian_rasterization /usr/local/lib/python3.10/site-packages/")
#     os.system("mv diff_gaussian_rasterization-0.0.0.dist-info /usr/local/lib/python3.10/site-packages/")

#     # pytorch3d
#     os.system("wget oss://virutalbuy-public/share/aigc3d/data/for_lingteng/LHM/pytorch3d.zip && wget oss://virutalbuy-public/share/aigc3d/data/for_lingteng/LHM/pytorch3d-0.7.8.dist-info.zip")
#     os.system("unzip pytorch3d.zip && unzip pytorch3d-0.7.8.dist-info.zip")
#     os.system("mv pytorch3d /usr/local/lib/python3.10/site-packages/")
#     os.system("mv pytorch3d-0.7.8.dist-info /usr/local/lib/python3.10/site-packages/")


# launch_env_compile_with_cuda()

def assert_input_image(input_image):
    """Stop the Gradio event chain when no input image has been provided.

    Raises:
        gr.Error: if *input_image* is ``None``.
    """
    if input_image is not None:
        return
    raise gr.Error("No image selected or uploaded!")

def prepare_working_dir():
    """Create and return a fresh ``tempfile.TemporaryDirectory``.

    The returned object's ``name`` attribute is the directory path.
    """
    from tempfile import TemporaryDirectory

    return TemporaryDirectory()

def init_preprocessor():
    """Build the module-level ``preprocessor`` instance that ``preprocess_fn`` reads."""
    global preprocessor
    from LHM.utils.preprocess import Preprocessor as _Preprocessor

    preprocessor = _Preprocessor()

def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool, working_dir):
    """Save *image_in* into *working_dir* and run the global preprocessor on it.

    Requires the module-level ``preprocessor`` (created by ``init_preprocessor``).

    Args:
        image_in: Image array accepted by ``PIL.Image.fromarray``.
        remove_bg: Forwarded to ``preprocessor.preprocess`` as ``rmbg``.
        recenter: Forwarded to ``preprocessor.preprocess`` as ``recenter``.
        working_dir: Object whose ``name`` attribute is a directory path
            (e.g. the ``TemporaryDirectory`` from ``prepare_working_dir``).

    Returns:
        Path of the preprocessed image, ``rembg.png`` inside *working_dir*.

    Raises:
        AssertionError: if the preprocessor reports failure.
    """
    image_raw = os.path.join(working_dir.name, "raw.png")
    with Image.fromarray(image_in) as img:
        img.save(image_raw)
    image_out = os.path.join(working_dir.name, "rembg.png")
    success = preprocessor.preprocess(image_path=image_raw, save_path=image_out, rmbg=remove_bg, recenter=recenter)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept so existing handlers still match.
    if not success:
        raise AssertionError("Failed under preprocess_fn!")
    return image_out

def get_image_base64(path):
    """Return the file at *path* encoded as a PNG ``data:`` URI string.

    The MIME type is always ``image/png``, whatever the file actually contains.
    """
    with open(path, "rb") as fh:
        raw_bytes = fh.read()
    payload = base64.b64encode(raw_bytes).decode()
    return "data:image/png;base64," + payload


def demo_lhm(infer_impl):
    """Build the LHM Gradio Blocks UI around *infer_impl*, then queue and launch it.

    Args:
        infer_impl: Inference callable invoked with the keyword arguments
            ``gradio_demo_image``, ``gradio_motion_file``,
            ``gradio_masked_image`` and ``gradio_video_save_path``. A truthy
            return value is treated as success.
    """
    # spaces.GPU wraps a *callable*. The previous code called infer_impl(...)
    # first and passed its result to spaces.GPU, so inference ran outside the
    # GPU wrapper and ``status`` was whatever spaces.GPU returned for a
    # non-callable. Wrap once, before launch, and call the wrapper per request.
    gpu_infer = spaces.GPU(infer_impl)

    def core_fn(image: np.ndarray, video_params, working_dir):
        """Run inference for one request.

        Args:
            image: Input image array from the numpy-typed ``gr.Image``.
            video_params: Path of the selected motion video; its basename up
                to the first ``_`` selects the SMPL-X params directory.
            working_dir: ``TemporaryDirectory`` produced by ``prepare_working_dir``.

        Returns:
            ``(image_path, video_path)`` on success, otherwise ``(None, None)``.
        """
        if video_params is None:
            # Without this, os.path.basename(None) raises a bare TypeError.
            raise gr.Error("No motion video selected!")

        image_raw = os.path.join(working_dir.name, "raw.png")
        with Image.fromarray(image) as img:
            img.save(image_raw)

        # e.g. ".../ex5/ex5_origin.mp4" -> "ex5"
        base_vid = os.path.basename(video_params).split("_")[0]
        smplx_params_dir = os.path.join("./assets/sample_motion", base_vid, "smplx_params")

        dump_video_path = os.path.join(working_dir.name, "output.mp4")
        dump_image_path = os.path.join(working_dir.name, "output.png")

        status = gpu_infer(
            gradio_demo_image=image_raw,
            gradio_motion_file=smplx_params_dir,
            gradio_masked_image=dump_image_path,
            gradio_video_save_path=dump_video_path,
        )
        if status:
            return dump_image_path, dump_video_path
        return None, None

    with gr.Blocks(analytics_enabled=False) as demo:

        logo_url = "./assets/rgba_logo_new.png"
        logo_base64 = get_image_base64(logo_url)
        gr.HTML(
            f"""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
                <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Animatable Human Model </h1>
            </div>
            </div>
            """
        )
        gr.HTML(
            """<p><h4 style="color: red;"> Notes: Please input full-body image in case of detection errors.</h4></p>"""
        )

        # DISPLAY
        with gr.Row():

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id="openlrm_input_image"):
                    with gr.TabItem('Input Image'):
                        with gr.Row():
                            input_image = gr.Image(label="Input Image", image_mode="RGBA", height=480, width=270, sources="upload", type="numpy", elem_id="content_image")
                # EXAMPLES
                with gr.Row():
                    examples = [
                        ['assets/sample_input/joker.jpg'],
                        ['assets/sample_input/anime.png'],
                        ['assets/sample_input/basket.png'],
                        ['assets/sample_input/ai_woman1.JPG'],
                        ['assets/sample_input/anime2.JPG'],
                        ['assets/sample_input/anime3.JPG'],
                        ['assets/sample_input/boy1.png'],
                        ['assets/sample_input/choplin.jpg'],
                        ['assets/sample_input/eins.JPG'],
                        ['assets/sample_input/girl1.png'],
                        ['assets/sample_input/girl2.png'],
                        ['assets/sample_input/robot.jpg'],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        examples_per_page=20,
                    )

            with gr.Column():
                with gr.Tabs(elem_id="openlrm_input_video"):
                    with gr.TabItem('Input Video'):
                        with gr.Row():
                            video_input = gr.Video(label="Input Video", height=480, width=270, interactive=False)

                examples = [
                    # './assets/sample_motion/danaotiangong/danaotiangong_origin.mp4',
                    './assets/sample_motion/ex5/ex5_origin.mp4',
                    './assets/sample_motion/girl2/girl2_origin.mp4',
                    './assets/sample_motion/jntm/jntm_origin.mp4',
                    './assets/sample_motion/mimo1/mimo1_origin.mp4',
                    './assets/sample_motion/mimo2/mimo2_origin.mp4',
                    './assets/sample_motion/mimo4/mimo4_origin.mp4',
                    './assets/sample_motion/mimo5/mimo5_origin.mp4',
                    './assets/sample_motion/mimo6/mimo6_origin.mp4',
                    './assets/sample_motion/nezha/nezha_origin.mp4',
                    './assets/sample_motion/taiji/taiji_origin.mp4'
                ]

                gr.Examples(
                    examples=examples,
                    inputs=[video_input],
                    examples_per_page=20,
                )
            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id="openlrm_processed_image"):
                    with gr.TabItem('Processed Image'):
                        with gr.Row():
                            processed_image = gr.Image(label="Processed Image", image_mode="RGBA", type="filepath", elem_id="processed_image", height=480, width=270, interactive=False)

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id="openlrm_render_video"):
                    with gr.TabItem('Rendered Video'):
                        with gr.Row():
                            output_video = gr.Video(label="Rendered Video", format="mp4", height=480, width=270, autoplay=True)

        # SETTING
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                submit = gr.Button('Generate', elem_id="openlrm_generate", variant='primary')

        # Event chain: validate input -> allocate a temp dir -> run inference.
        working_dir = gr.State()
        submit.click(
            fn=assert_input_image,
            inputs=[input_image],
            queue=False,
        ).success(
            fn=prepare_working_dir,
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=core_fn,
            inputs=[input_image, video_input, working_dir],  # video_input selects the SMPL-X params dir
            outputs=[processed_image, output_video],
        )

        demo.queue()
        demo.launch()


def launch_gradio_app():
    """Configure the LHM runner through environment variables and serve the demo.

    Sets the ``APP_*`` and ``NUMBA_THREADING_LAYER`` environment variables,
    instantiates the runner class registered under ``APP_TYPE``, moves its
    pose estimator to CUDA and passes ``runner.infer`` to ``demo_lhm``.
    """
    os.environ.update({
        "APP_ENABLED": "1",
        "APP_MODEL_NAME": "./exps/releases/video_human_benchmark/human-lrm-500M/step_060000/",
        "APP_INFER": "./configs/inference/human-lrm-500M.yaml",
        "APP_TYPE": "infer.human_lrm",
        "NUMBA_THREADING_LAYER": 'omp',
    })

    # torch is used below but was never imported in this module, so the
    # torch.device call raised NameError. Imported after the env setup,
    # alongside the other heavy runtime imports.
    import torch
    from LHM.runners import REGISTRY_RUNNERS

    RunnerClass = REGISTRY_RUNNERS[os.getenv("APP_TYPE")]
    with RunnerClass() as runner:
        runner.pose_estimator.device = torch.device('cuda')
        runner.pose_estimator.mhmr_model.cuda()
        demo_lhm(infer_impl=runner.infer)


if __name__ == '__main__':
    # Fetch runtime assets/weights first, then install the pip-only deps.
    for _setup_step in (launch_pretrained, launch_env_not_compile_with_cuda):
        _setup_step()
    # NOTE(review): the app launch is disabled, so running this script only
    # prepares the environment — confirm that is intentional.
    # launch_gradio_app()