import spaces
import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image

# Hugging Face access token, taken from the HF_API_TOKEN environment variable
# (None when the variable is not set).
hf_token = os.environ.get("HF_API_TOKEN")

# Hub repository ids of the two models used by the pipeline.
vqa_model_name = "llava-hf/llava-v1.6-mistral-7b-hf"  # image -> description
language_model_name = "larry1129/WooWoof_AI_Vision_merged_16bit_3b"  # description + instruction -> answer

# Module-level caches; the load_* helpers fill them on first use and reuse
# them on later calls.
vqa_processor = vqa_model = None
language_tokenizer = language_model = None

# Lazily initialise the image-captioning (VQA) model.
def load_vqa_model():
    """Load the LLaVA-NeXT processor and model once and cache them in globals.

    On the first call (or if either cached object is missing) the processor and
    the float16 model for ``vqa_model_name`` are loaded and the model is moved
    to ``cuda:0``. Later calls return the cached objects.

    Returns:
        tuple: ``(vqa_processor, vqa_model)``.
    """
    global vqa_processor, vqa_model
    if vqa_processor is None or vqa_model is None:
        # Forward the token to BOTH loaders (previously only the processor got
        # it, so gated/private weights would fail to download). ``token=`` is
        # the replacement for the deprecated ``use_auth_token=`` argument.
        vqa_processor = LlavaNextProcessor.from_pretrained(vqa_model_name, token=hf_token)
        vqa_model = LlavaNextForConditionalGeneration.from_pretrained(
            vqa_model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            token=hf_token,
        ).to("cuda:0")
    return vqa_processor, vqa_model

# Lazily initialise the text-only language model.
def load_language_model():
    """Load the causal language model and its tokenizer once and cache them.

    On the first call (or if either cached object is missing) the tokenizer and
    the float16 model for ``language_model_name`` are loaded with
    ``device_map="auto"``, the pad token is set to EOS, and the model is put in
    eval mode. Later calls return the cached objects.

    Returns:
        tuple: ``(language_tokenizer, language_model)``.
    """
    global language_tokenizer, language_model
    if language_tokenizer is None or language_model is None:
        # Forward the token to BOTH loaders (previously only the tokenizer got
        # it, so a private/gated checkpoint would fail when loading weights).
        # ``token=`` replaces the deprecated ``use_auth_token=`` argument.
        language_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        language_model = AutoModelForCausalLM.from_pretrained(
            language_model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=hf_token,
        )
        # Pad with EOS (this overwrites any pad token the tokenizer already had)
        # and record the same id on the model config.
        language_tokenizer.pad_token = language_tokenizer.eos_token
        language_model.config.pad_token_id = language_tokenizer.pad_token_id
        language_model.eval()  # inference only
    return language_tokenizer, language_model

# Generate a text description of an image with the VQA model.
# Decorated with @spaces.GPU so a GPU is attached for the duration of the call.
@spaces.GPU(duration=40)  # consider raising duration to 120 (the first call also loads the model here)
def generate_image_description(image):
    """Ask the LLaVA-NeXT model what is shown in ``image``.

    Args:
        image: the image to describe (the Gradio UI passes a PIL image).

    Returns:
        str: the model's answer only, without the chat-template prompt text.
    """
    vqa_processor, vqa_model = load_vqa_model()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {"type": "image"},
            ],
        },
    ]
    prompt = vqa_processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = vqa_processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")

    with torch.no_grad():
        output = vqa_model.generate(**inputs, max_new_tokens=100)
    # generate() returns the prompt ids followed by the new tokens; decoding the
    # whole sequence would leak the template ("[INST] ... [/INST]") into the
    # description, so decode only what comes after the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    image_description = vqa_processor.decode(output[0][prompt_len:], skip_special_tokens=True)
    return image_description.strip()

# Produce the final answer with the text-only language model.
# Decorated with @spaces.GPU so a GPU is attached for the duration of the call.
@spaces.GPU(duration=40)  # consider raising duration to 120 (the first call also loads the model here)
def generate_language_response(instruction, image_description):
    """Answer ``instruction`` using ``image_description`` as context.

    Builds an Instruction / Input / Response prompt and samples up to 128 new
    tokens (temperature 0.7, top-p 0.95).

    Args:
        instruction: the user's instruction text.
        image_description: description produced by the VQA model.

    Returns:
        str: the generated response text only.
    """
    language_tokenizer, language_model = load_language_model()
    prompt = f"""### Instruction:
{instruction}
### Input:
{image_description}
### Response:
"""
    inputs = language_tokenizer(prompt, return_tensors="pt").to(language_model.device)
    with torch.no_grad():
        outputs = language_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )
    # generate() returns the prompt ids followed by the continuation; decode only
    # the new tokens instead of re-splitting the full text on "### Response:".
    prompt_len = inputs["input_ids"].shape[1]
    response = language_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Sampling may run past the answer and start another template section;
    # keep only the text before the first such header.
    for marker in ("### Instruction:", "### Input:", "### Response:"):
        response = response.split(marker, 1)[0]
    return response.strip()

# Gradio entry point combining both stages.
def process_image_and_text(image, instruction):
    """Describe the image, then answer the instruction from that description.

    Args:
        image: image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user submits without uploading one.
        instruction: free-text instruction from the textbox.

    Returns:
        str: the description and the final answer formatted for the text
        output, or a prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        # Without this guard, images=None would reach the VQA processor even
        # though the chat prompt contains an image slot.
        return "请先上传图片。"
    image_description = generate_image_description(image)
    final_response = generate_language_response(instruction, image_description)
    return f"图片描述: {image_description}\n\n最终回答: {final_response}"

# Build the Gradio UI: an image upload and an instruction textbox as inputs,
# with the combined description/answer shown as plain text.
image_input = gr.Image(type="pil", label="上传图片")
instruction_input = gr.Textbox(lines=2, placeholder="Instruction", label="Instruction")

iface = gr.Interface(
    fn=process_image_and_text,
    inputs=[image_input, instruction_input],
    outputs="text",
    title="WooWoof AI - 图片和文本交互",
    description="输入图片并添加指令,生成基于图片描述的回答。",
    allow_flagging="never",
)

# Start the Gradio server.
iface.launch()