import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
import spaces
import tensorrt
import multiprocessing as mp
import pickle
# Child processes are started with 'spawn'; force=True replaces any start method set earlier.
mp.set_start_method('spawn', force=True)


# Mirror the KDTalker Hub repository into a local folder (presumably model weights/assets
# used by Inferencer -- confirm against difpoint.inference).
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)

# Print the path of every file now present under local_dir.
print("\nFiles downloaded:")
for dirpath, _subdirs, filenames in os.walk(local_dir):
    for filename in filenames:
        print(os.path.join(dirpath, filename))

# NOTE(review): no code in this file reads result_dir; confirm whether it is still needed.
result_dir = "results"
def set_upload():
    """Return the source tag ``"upload"`` (written into a hidden type textbox by UI events)."""
    selected_source = "upload"
    return selected_source
def set_microphone():
    """Return the source tag ``"microphone"`` (written into a hidden type textbox by UI events)."""
    selected_source = "microphone"
    return selected_source
def set_tts():
    """Return the source tag ``"tts"`` (written into a hidden type textbox by UI events)."""
    selected_source = "tts"
    return selected_source
def create_kd_talker():
    """Construct and return a fresh ``Inferencer`` (the KDTalker generation pipeline)."""
    talker = Inferencer()
    return talker

# Bundled demo inputs: "Example N" -> example/exampleN.wav (driving audio) + example/exampleN.png (portrait).
example_folder = "example"
example_choices = [f"Example {number}" for number in range(1, 4)]
example_mapping = {
    label: {
        "audio": os.path.join(example_folder, f"example{number}.wav"),
        "image": os.path.join(example_folder, f"example{number}.png"),
    }
    for number, label in enumerate(example_choices, start=1)
}

# Coqui voice-cloning model used by predict(); loaded lazily on first use and then reused.
_TTS_MODEL_NAME = 'tts_models/multilingual/multi-dataset/your_tts'
_tts_model = None


def _get_tts_model():
    """Return the shared Coqui ``TTS`` instance, constructing it on the first call.

    Constructing ``TTS(...)`` loads (and may download) the model, which is far too
    slow to repeat on every button click.
    """
    global _tts_model
    if _tts_model is None:
        _tts_model = TTS(_TTS_MODEL_NAME)
    return _tts_model


@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise English speech for ``prompt`` in the voice of a reference clip.

    Args:
        prompt: Text to speak.
        upload_reference_audio: File path of the uploaded reference clip, or None.
        microphone_reference_audio: File path of the recorded reference clip, or None.
        reference_audio_type: ``"upload"`` or ``"microphone"``, i.e. which clip the
            user touched most recently.

    Returns:
        gr.Audio pointing at ``./downloaded_repo/output.wav``. Every call overwrites
        that same file.

    Raises:
        gr.Error: If the type tag is unknown, no reference clip is available, or the
            prompt is empty.
    """
    output_file_path = os.path.join('./downloaded_repo/', 'output.wav')

    if reference_audio_type == 'upload':
        audio_file_pth = upload_reference_audio
    elif reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    else:
        # Previously this left audio_file_pth unbound (UnboundLocalError).
        raise gr.Error(f"Unknown reference audio type: {reference_audio_type!r}")

    # The type tag follows the latest change event, which can also fire when a clip is
    # cleared; fall back to whichever reference clip is actually present.
    if not audio_file_pth:
        audio_file_pth = upload_reference_audio or microphone_reference_audio
    if not audio_file_pth:
        raise gr.Error("Please upload or record a reference audio clip first.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter some text to synthesise.")

    _get_tts_model().tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')

@spaces.GPU
def generate(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image, smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t):
    """Render a talking-portrait video from a source image and driving audio.

    Args:
        upload_driven_audio: File path of the uploaded driving audio, or None.
        tts_driven_audio: File path of the TTS-synthesised driving audio, or None.
        driven_audio_type: ``"upload"`` or ``"tts"``, forwarded to the inferencer
            (presumably to choose between the two audio inputs -- confirm in
            ``Inferencer.generate_with_audio_img``).
        source_image: File path of the portrait image, or None.
        smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t: Slider values in
            [0, 1], forwarded unchanged.

    Returns:
        Whatever ``Inferencer.generate_with_audio_img`` returns. The UI shows it in a
        ``gr.Video``, so presumably a video file path.

    Raises:
        gr.Error: If the image or both audio inputs are missing. This replaces an
            opaque failure deeper in the pipeline.
    """
    if not source_image:
        raise gr.Error("Please upload a source image first.")
    if not upload_driven_audio and not tts_driven_audio:
        raise gr.Error("Please upload a driving audio clip or generate one with TTS first.")

    # A new pipeline per request, built via the module's factory.
    talker = create_kd_talker()
    return talker.generate_with_audio_img(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                          smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t)


def main():
    """Build the KDTalker Gradio interface and return it without launching it.

    Layout:

    * Left column: the source portrait image and the driving audio. The audio is
      either uploaded ("Upload" tab) or synthesised from text by ``predict`` using a
      reference voice clip ("TTS" tab).
    * Right column: the generated video, pitch/yaw/roll/T sliders feeding
      ``generate``, and an "Example" tab that loads bundled audio/image pairs.

    Two hidden textboxes record which input the user touched last.
    ``reference_audio_type`` ("upload" / "microphone") selects the reference clip in
    ``predict``. ``driven_audio_type`` ("upload" / "tts") is forwarded to ``generate``.

    Returns:
        gr.Blocks: the assembled, unlaunched interface.
    """
    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            gr.HTML(
            """
                <div align='center'>
                    <h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
                    <div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
                        <img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
                        <img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
                    </div>
                </div>
            """
            )
            # Hidden state: which audio source the user touched most recently.
            driven_audio_type = gr.Textbox(value="upload", visible=False)
            reference_audio_type = gr.Textbox(value="upload", visible=False)
        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)
                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        # Second listener on the same button: mark the TTS output as the driving audio.
                        tts_button.click(set_tts, outputs=driven_audio_type)
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                                fn=generate,
                                inputs=[
                                    upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                    smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t
                                ],
                                outputs=[gen_video]
                            )
                    with gr.TabItem("Example"):
                        example_choice = gr.Dropdown(choices=example_choices, label="Choose an example")

                        def load_example(choice):
                            """Return ``[audio_path, image_path]`` for ``choice``, or ``[None, None]`` if unknown."""
                            example = example_mapping.get(choice)
                            if example is None:
                                # None clears the components; an empty-string path is not a valid file.
                                return [None, None]
                            return [example["audio"], example["image"]]

                        example_choice.change(
                            fn=load_example,
                            inputs=[example_choice],
                            outputs=[upload_driven_audio, source_image]
                        )
                        # Examples supply uploaded-style audio, so mark the driving source accordingly.
                        example_choice.change(set_upload, outputs=driven_audio_type)

    return interface


# Build the UI at module level so `demo` stays importable by tooling that loads this file.
demo = main()

# Launch only when run as a script. This file selects the 'spawn' multiprocessing start
# method, and spawned children re-import the main module; an unguarded launch would start
# another server in every child process.
if __name__ == "__main__":
    demo.queue().launch(share=True)