import base64
from io import BytesIO

import torch
from fastapi import Body, FastAPI, File, HTTPException, Query, UploadFile
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

# ASGI application object; every route decorator in this module registers on it.
app: FastAPI = FastAPI()


class PredictRequest(BaseModel):
    """JSON body accepted by the ``POST /predict`` endpoint."""

    # Base64-encoded image data to be described by the model.
    image_base64: str
    # Instruction or question for the model about the image.
    prompt: str


# checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
# min_pixels = 256 * 28 * 28
# max_pixels = 1280 * 28 * 28
# processor = AutoProcessor.from_pretrained(
#     checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
# )
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     checkpoint,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     # attn_implementation="flash_attention_2",
# )

# Active model: Qwen2.5-VL 7B Instruct, loaded once at import time and used by
# every request handled by this module. The processor rescales each image so
# its pixel count lies within [min_pixels, max_pixels]; per the Qwen2.5-VL model
# card one 28x28 pixel area maps to one visual token (here 256-1280 tokens).
checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
min_pixels, max_pixels = 256 * 28 * 28, 1280 * 28 * 28

processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# bfloat16 weights; device_map="auto" lets transformers/Accelerate place the
# layers on the available device(s). Uncomment attn_implementation to use
# FlashAttention-2 (requires the flash-attn package to be installed).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",
)


@app.get("/")
def read_root():
    """Report that the API is running and point callers at ``/predict``."""
    status = {"message": "API is live. Use the /predict endpoint."}
    return status


# def encode_image(image_path, max_size=(800, 800), quality=85):
#     """
#     Converts an image from a local file path to a Base64-encoded string with optimized size.

#     Args:
#         image_path (str): The path to the image file.
#         max_size (tuple): The maximum width and height of the resized image.
#         quality (int): The compression quality (1-100, higher means better quality but bigger size).

#     Returns:
#         str: Base64-encoded representation of the optimized image.
#     """
#     try:
#         with Image.open(image_path) as img:
#             # Convert to RGB (avoid issues with PNG transparency)
#             img = img.convert("RGB")

#             # Resize while maintaining aspect ratio
#             img.thumbnail(max_size, Image.LANCZOS)

#             # Save to buffer with compression
#             buffer = BytesIO()
#             img.save(
#                 buffer, format="JPEG", quality=quality
#             )  # Save as JPEG to reduce size
#             return base64.b64encode(buffer.getvalue()).decode("utf-8")
#     except Exception as e:
#         print(f"❌ Error encoding image {image_path}: {e}")
#         return None


def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """Re-encode image bytes as a downscaled JPEG and return it as Base64 text.

    The image is converted to RGB (JPEG cannot store an alpha channel), shrunk
    with LANCZOS resampling to fit within ``max_size`` while keeping its aspect
    ratio (``thumbnail`` never enlarges), and saved as JPEG at ``quality``.

    Args:
        image_data: In-memory file object holding the raw image bytes.
        max_size: Maximum ``(width, height)`` of the output image.
        quality: JPEG quality passed to Pillow (higher = sharper but larger).

    Returns:
        str: Base64 text (no ``data:`` URI prefix) of the JPEG bytes.

    Raises:
        HTTPException: 400 if Pillow cannot decode the input (unrecognised
            format, truncated data, or the decompression-bomb size limit);
            500 for any other failure while re-encoding.
    """
    try:
        with Image.open(image_data) as img:
            rgb = img.convert("RGB")
            rgb.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            rgb.save(buffer, format="JPEG", quality=quality)
    except (OSError, Image.DecompressionBombError) as e:
        # Undecodable or oversized input is a client error, not a server fault.
        # (PIL.UnidentifiedImageError is a subclass of OSError.)
        raise HTTPException(
            status_code=400, detail=f"Error encoding image: {e}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error encoding image: {e}"
        ) from e
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/encode-image/")
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.

    The upload is re-encoded by ``encode_image`` (RGB, downscaled, JPEG).

    Returns:
        dict: ``{"filename": <uploaded filename>, "encoded_image": <base64 str>}``.

    Raises:
        HTTPException: 400 whenever reading or encoding the upload fails.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
    except HTTPException as e:
        # encode_image already raises an HTTPException; formatting it with
        # str(e) would also embed its status code in this message, so reuse
        # only its detail text while keeping this endpoint's 400 status.
        raise HTTPException(
            status_code=400, detail=f"Invalid file: {e.detail}"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}") from e
    return {"filename": file.filename, "encoded_image": encoded_string}


def _as_image_data_uri(image_base64: str) -> str:
    """Return *image_base64* wrapped as a ``data:image;base64,`` URI.

    Clients sometimes send a complete data URI (``data:image/png;base64,...``)
    instead of bare base64. Wrapping that again would leave the inner prefix
    inside the base64 payload, so any existing prefix up to and including the
    first ``base64,`` is stripped before wrapping.
    """
    payload = image_base64.strip()
    if payload.startswith("data:") and "base64," in payload:
        payload = payload.split("base64,", 1)[1]
    return f"data:image;base64,{payload}"


@app.post("/predict")
def predict(data: PredictRequest):
    """
    Answer ``data.prompt`` about a base64-encoded image using Qwen2.5-VL.

    Args:
        data: Request body with ``image_base64`` (bare base64, or a full
            ``data:image/...;base64,`` URI) and the text ``prompt``.

    Returns:
        dict: ``{"response": <generated text>}``; the text is
        ``"No description generated."`` when the model produced nothing.

    Raises:
        HTTPException: 400 when ``process_vision_info`` cannot decode or
            preprocess the supplied image.
    """
    # Single-turn chat: one image followed by the user's prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": _as_image_data_uri(data.image_base64)},
                {"type": "text", "text": data.prompt},
            ],
        }
    ]

    # Render the chat template to text; tokenization happens in processor() below.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The image is decoded here, so failures reflect a bad client payload
    # (e.g. invalid base64 or an unrecognised image) rather than a server error.
    try:
        image_inputs, video_inputs = process_vision_info(messages)
    except (OSError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # NOTE(review): 2056 looks like it may be a typo for 2048 — confirm the cap.
    generated_ids = model.generate(**inputs, max_new_tokens=2056)
    # Drop the prompt tokens from each returned sequence so only new text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # batch_decode yields one string per sequence, so an empty generation shows
    # up as "" rather than an empty list; fall back in both cases.
    response = output_text[0] if output_text else ""
    return {"response": response or "No description generated."}


# @app.get("/predict")
# def predict(image_url: str = Query(...), prompt: str = Query(...)):

#     image = encode_image(image_url)

#     messages = [
#         {
#             "role": "system",
#             "content": "You are a helpful assistant with vision abilities.",
#         },
#         {
#             "role": "user",
#             "content": [
#                 {"type": "image", "image": f"data:image;base64,{image}"},
#                 {"type": "text", "text": prompt},
#             ],
#         },
#     ]
#     text = processor.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#     image_inputs, video_inputs = process_vision_info(messages)
#     inputs = processor(
#         text=[text],
#         images=image_inputs,
#         videos=video_inputs,
#         padding=True,
#         return_tensors="pt",
#     ).to(model.device)
#     with torch.no_grad():
#         generated_ids = model.generate(**inputs, max_new_tokens=128)
#     generated_ids_trimmed = [
#         out_ids[len(in_ids) :]
#         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
#     ]
#     output_texts = processor.batch_decode(
#         generated_ids_trimmed,
#         skip_special_tokens=True,
#         clean_up_tokenization_spaces=False,
#     )
#     return {"response": output_texts[0]}