import torch
import spaces
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus
from huggingface_hub import hf_hub_download
from insightface.app import FaceAnalysis
from insightface.utils import face_align
import gradio as gr
import cv2

# Model paths
base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
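# FaceID Plus v2 conditions on both the ArcFace ID embedding and CLIP features
# of an aligned face crop, which preserves face structure better than ID alone.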
ip_plus_ckpt = hf_hub_download(
    repo_id="h94/IP-Adapter-FaceID",
    filename="ip-adapter-faceid-plusv2_sd15.bin",
    repo_type="model",
)

device = "cuda"  # GPU access is mediated by the @spaces.GPU decorator below

# DDIM scheduler with the standard Stable Diffusion v1.5 noise schedule
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# Load models
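# sd-vae-ft-mse is a fine-tuned SD VAE that produces smoother reconstructions,
# which noticeably helps rendered faces.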
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae
).to(device)

# Wrap the pipeline with the FaceID Plus v2 adapter (ID embedding + CLIP face features).
ip_model_plus = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_plus_ckpt, device)

# Initialize FaceAnalysis
app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
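# Face detection runs on CPU (CPUExecutionProvider) so the GPU stays free for
# diffusion; det_size=(640, 640) is insightface's default detector resolution.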

cv2.setNumThreads(1)  # limit OpenCV's internal threading to avoid CPU oversubscription

@spaces.GPU(enable_queue=True)
def generate_image(images, gender, prompt, progress=gr.Progress(track_tqdm=True)):
    if not images:
        raise gr.Error("Please upload at least one photo of your face.")
    if not prompt:
        prompt = f"A portrait photo of a {gender.lower()}"  # Default prompt
    
    faceid_all_embeds = []
    first_iteration = True
    preserve_face_structure = True
    face_strength = 2.1      # passed as s_scale: weight of the structural CLIP face features
    likeness_strength = 0.7  # passed as scale: weight of the ID embedding

    for image in images:
        face = cv2.imread(image)
        faces = app.get(face)
        if not faces:
            raise gr.Error(f"No face detected in {image}. Please upload clear, front-facing photos.")
        faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        faceid_all_embeds.append(faceid_embed)

        # Keep an aligned 224x224 crop of the first face for the Plus model's CLIP encoder.
        if first_iteration and preserve_face_structure:
            face_image = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224)
            first_iteration = False

    # Average the ID embeddings across all uploaded photos for a more stable identity.
    average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)

    # shortcut=True selects the v2 conditioning path of the Plus adapter.
    generated_images = ip_model_plus.generate(
        prompt=prompt,
        faceid_embeds=average_embedding,
        scale=likeness_strength,
        face_image=face_image,
        shortcut=True,
        s_scale=face_strength,
        width=512,
        height=912,
        num_inference_steps=100,
    )
    return generated_images

css = '''
footer { visibility: hidden; }
h1 { margin-bottom: 0 !important; }
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Generation with Face ID")
    gr.Markdown("Upload your face images and enter a prompt to generate images.")

    with gr.Row():
        with gr.Column():
            images_input = gr.Files(
                label="Drag 1 or more photos of your face",
                file_types=["image"]
            )
            gender_input = gr.Radio(
                label="Select Gender", 
                choices=["Female", "Male"], 
                value="Female", 
                type="value"
            )
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Describe the image you want to generate..."
            )
            run_button = gr.Button("Generate Image")

        with gr.Column():
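            # generate_image returns a list of PIL images, which Gallery renders directly.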
            output_gallery = gr.Gallery(label="Generated Images")

    # Define the event handler for the button click
    run_button.click(
        fn=generate_image,
        inputs=[images_input, gender_input, prompt_input],
        outputs=output_gallery
    )

# Launch the interface
demo.queue()
demo.launch()