import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
# MultiModalityCausalLM is imported for its side effect of registering the
# Janus model classes with the transformers Auto* loaders
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces

# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
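# Force the eager attention implementation for broad kernel/hardware compatibility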
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
# Use bfloat16 on GPU; half precision is poorly supported on CPU, so fall back to float32 there
model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
vl_gpt = vl_gpt.to(model_dtype)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def multimodal_understanding(image, question, seed, top_p, temperature):
    # Clear the CUDA cache before generating
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Seed all RNGs for reproducible results
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    
    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    
    # The processor expects PIL images; Gradio supplies a NumPy array
    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=model_dtype)
    
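    # Encode the image and splice its embeddings into the text embedding sequence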
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
    
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=4000,
        do_sample=temperature > 0,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )
    
    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer

def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Clear the CUDA cache before generating
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
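    # Classifier-free guidance needs two copies of each sample: even rows hold
    # the full (conditional) prompt, odd rows keep only the first and last
    # tokens and pad the rest to form the unconditional prompt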
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
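    # Autoregressively sample one image token per step, reusing the KV cache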
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
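            # Classifier-free guidance: mix conditional (even rows) and
            # unconditional (odd rows) logits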
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
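            # Duplicate each sampled token so both the conditional and
            # unconditional streams receive it as input for the next step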
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

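    # Decode the sampled VQ codes back into image pixels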
    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches

def unpack(dec, width, height, parallel_size=5):
    # Move channels last and map pixel values from [-1, 1] to [0, 255]
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    # Image arrays are indexed (height, width); the two are equal here (384x384)
    visual_img = np.zeros((parallel_size, height, width, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img

@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt,
                   seed=None,
                   guidance=5,
                   t2i_temperature=1.0):
    # Clear the CUDA cache and avoid tracking gradients
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
    width = 384
    height = 384
    parallel_size = 5
    
    with torch.no_grad():
        messages = [{'role': '<|User|>', 'content': prompt},
                    {'role': '<|Assistant|>', 'content': ''}]
        # Render the chat template, then append the image-start tag so the
        # model switches to emitting image tokens
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='')
        text = text + vl_chat_processor.image_start_tag
        
        input_ids = torch.LongTensor(tokenizer.encode(text))
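        # Round width/height down to multiples of the 16-pixel patch size (a no-op at 384)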
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance,
                                   parallel_size=parallel_size,
                                   temperature=t2i_temperature)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16,
                        parallel_size=parallel_size)

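        # Upscale the 384x384 samples to 768x768 for display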
        return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]


# Custom CSS as a string
custom_css = """
    .gradio-container {
        font-family: 'Inter', -apple-system, sans-serif;
    }
    .image-preview {
        min-height: 300px;
        max-height: 500px;
        width: 100%;
        object-fit: contain;
        border-radius: 8px;
        border: 2px solid #eee;
    }
    .tab-nav {
        background: white;
        padding: 1rem;
        border-radius: 8px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.05);
    }
    .examples-table {
        font-size: 0.9rem;
    }
    .gr-button.gr-button-lg {
        padding: 12px 24px;
        font-size: 1.1rem;
    }
    .gr-input, .gr-select {
        border-radius: 6px;
    }
    .gr-form {
        background: white;
        padding: 20px;
        border-radius: 12px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.05);
    }
    .gr-panel {
        border: none;
        background: transparent;
    }
    .footer {
        text-align: center;
        margin-top: 2rem;
        padding: 1rem;
        color: #666;
    }
"""

# Gradio interface with improved UI
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo"),
    css=custom_css
) as demo:
    gr.Markdown(
        """
        # DeepSeek Janus-Pro Multimodal
        ### Advanced AI for Visual Understanding and Generation
        
        This powerful multimodal AI system combines:
        * **Visual Analysis**: Advanced image understanding and medical image interpretation
        * **Creative Generation**: High-quality image generation from text descriptions
        * **Interactive Chat**: Natural conversation about visual content
        """
    )
    
    with gr.Tabs():
        # Visual Chat Tab
        with gr.Tab("Visual Understanding"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Upload Image",
                        type="numpy",
                        elem_classes="image-preview"
                    )
                    
                with gr.Column(scale=1):
                    question_input = gr.Textbox(
                        label="Question or Analysis Request",
                        placeholder="Ask a question about the image or request detailed analysis...",
                        lines=3
                    )
                    with gr.Row():
                        und_seed_input = gr.Number(
                            label="Seed",
                            precision=0,
                            value=42,
                            container=False
                        )
                        top_p = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.95,
                            step=0.05,
                            label="Top-p",
                            container=False
                        )
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.1,
                            step=0.05,
                            label="Temperature",
                            container=False
                        )
                    
                    understanding_button = gr.Button(
                        "Analyze Image",
                        variant="primary"
                    )
            
            understanding_output = gr.Textbox(
                label="Analysis Results",
                lines=10,
                show_copy_button=True
            )
            
            with gr.Accordion("Medical Analysis Examples", open=False):
                gr.Examples(
                    examples=[
                        [
                            """You are an AI assistant trained to analyze medical images...""",
                            "fundus.webp",
                        ],
                    ],
                    inputs=[question_input, image_input],
                )

        # Image Generation Tab
        with gr.Tab("Image Generation"):
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Image Description",
                    placeholder="Describe the image you want to create in detail...",
                    lines=3
                )
                
                with gr.Row():
                    cfg_weight_input = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=0.5,
                        label="Guidance Scale",
                        info="Higher values create images that more closely match your prompt"
                    )
                    t2i_temperature = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=1.0,
                        step=0.05,
                        label="Temperature",
                        info="Controls randomness in generation"
                    )
                    seed_input = gr.Number(
                        label="Seed (Optional)",
                        precision=0,
                        value=12345,
                        info="Set for reproducible results"
                    )
                
                generation_button = gr.Button(
                    "Generate Images",
                    variant="primary"
                )
                
                image_output = gr.Gallery(
                    label="Generated Images",
                    columns=3,
                    rows=2,
                    height=500,
                    object_fit="contain"
                )
                
                with gr.Accordion("Generation Examples", open=False):
                    gr.Examples(
                        examples=[
                            "Master shifu racoon wearing drip attire as a street gangster.",
                            "The face of a beautiful girl",
                            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                            "A glass of red wine on a reflective surface.",
                            "A cute and adorable baby fox with big brown eyes...",
                        ],
                        inputs=prompt_input,
                    )

    # Connect components
    understanding_button.click(
        fn=multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output
    )
    
    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
        outputs=image_output
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch(share=True)