import os
import spaces
# Resolve the Hugging Face access token. Prefer the HF_TOKEN environment
# variable (an empty value counts as missing) and otherwise fall back to a
# placeholder that must be replaced for authenticated Hub access. The chosen
# value is written back to HF_TOKEN so that any library code in this process
# that reads that variable sees the same token.
token = os.environ.get('HF_TOKEN')
if not token:
    print("paste your hf token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
    os.environ['HF_TOKEN'] = token
import torch
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes

from transformers import AutoTokenizer, AutoModel

# ========================================
#             Model Initialization
# ========================================

# Hub id of the chat checkpoint; tokenizer and model must come from the same repo.
MODEL_ID = 'OpenGVLab/InternVideo2_chat_8B_HD'

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
    token=token)

# Load the weights once in bfloat16 (passing the same token as the tokenizer),
# then move them to the GPU only if CUDA is visible at startup.
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=token)
if torch.cuda.is_available():
    model = model.cuda()

from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import numpy as np
import decord
from decord import VideoReader, cpu
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import PILToTensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
# Have decord return torch tensors from VideoReader.get_batch (rather than its
# native array type); load_video depends on this because it calls .permute on
# the batch, and it also re-applies this setting on every call.
decord.bridge.set_bridge("torch")

# ========================================
#          Define Utils
# ========================================
def get_index(num_frames, num_segments):
    """Return ``num_segments`` frame indices spread evenly over ``num_frames``.

    The span ``[0, num_frames - 1]`` is cut into equal segments; each index is
    the rounded segment offset shifted by half a segment (truncated to int).
    """
    step = float(num_frames - 1) / num_segments
    half_step = int(step / 2)
    return np.array([half_step + int(np.round(step * k)) for k in range(num_segments)])


def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=4, padding=False):
    """Sample frames from a video and cut them into model-ready square tiles.

    Args:
        video_path: path opened with ``decord.VideoReader`` on CPU.
        num_segments: number of frames to sample (indices from ``get_index``).
        return_msg: when true, also return a sentence listing sample timestamps.
        resolution: side length of every square tile.
        hd_num: tile budget forwarded to the HD resize helper.
        padding: use ``HD_transform_padding`` when truthy, otherwise
            ``HD_transform_no_padding``.

    Returns:
        Tensor of shape (1, n_tiles + 1, T, 3, resolution, resolution): the
        grid tiles (row-major) followed by one downscaled global view; plus the
        timestamp message when ``return_msg`` is true.
    """
    decord.bridge.set_bridge("torch")
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    frame_indices = get_index(len(reader), num_segments)

    # decord yields channels-last frames; move channels to dim 1 -> (T, C, H, W).
    clip = reader.get_batch(frame_indices).permute(0, 3, 1, 2)

    resize_hd = HD_transform_padding if padding else HD_transform_no_padding
    clip = resize_hd(clip.float(), image_size=resolution, hd_num=hd_num)

    # Scale pixel values by 1/255, then normalise with the ImageNet mean/std.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    clip = normalize(clip.float().div(255.0))

    n_frames, _, height, width = clip.shape
    rows = height // resolution
    cols = width // resolution
    # Split every frame into a rows x cols grid and gather the tiles in front.
    tiles = (
        clip.reshape(1, n_frames, 3, rows, resolution, cols, resolution)
        .permute(0, 3, 5, 1, 2, 4, 6)
        .reshape(-1, n_frames, 3, resolution, resolution)
        .contiguous()
    )
    # One extra "global" tile: the whole frame squeezed to a single tile.
    global_view = F.interpolate(
        clip.float(), size=(resolution, resolution), mode='bicubic', align_corners=False
    ).to(tiles.dtype).unsqueeze(0)

    video = torch.cat([tiles, global_view]).unsqueeze(0)

    if not return_msg:
        return video

    fps = float(reader.get_avg_fps())
    sec = ", ".join(str(round(idx / fps, 1)) for idx in frame_indices)
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
    return video, msg

def HD_transform_padding(frames, image_size=224, hd_num=6):
    """Resize frames onto a grid of ``image_size`` tiles, padding instead of cropping.

    The frame is processed in landscape orientation. Portrait input has its
    spatial axes swapped first and swapped back at the end. The long side is
    resized to ``scale * image_size``, where ``scale`` is the largest integer
    with ``scale * ceil(scale / ratio) <= hd_num``. The short side keeps the
    aspect ratio and is then centre-padded with value 255 up to the next
    multiple of ``image_size``.

    Args:
        frames: float tensor (N, C, H, W). The padding value is 255, so pixel
            values are presumably in [0, 255]; load_video calls this before
            normalising.
        image_size: tile side length. Both output sides are multiples of it.
        hd_num: upper bound used when choosing the tile grid.

    Returns:
        Tensor (N, C, H', W') with H' and W' multiples of ``image_size`` and the
        same orientation (portrait/landscape) as the input.
    """
    def _pad_height_to_tile(x):
        # Centre-pad the height up to the next multiple of image_size.
        _, _, h, _ = x.shape
        target = int(np.ceil(h / image_size) * image_size)
        top = (target - h) // 2
        bottom = target - h - top
        return F.pad(x, pad=[0, 0, top, bottom], mode='constant', value=255)

    _, _, H, W = frames.shape
    # Swap H and W for portrait input. flip() would only mirror the pixels and
    # keep the frame portrait, so the landscape resize below would distort it.
    transposed = W < H
    if transposed:
        frames = frames.transpose(-2, -1)
        H, W = W, H

    ratio = W / H
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    # Guard degenerate inputs (hd_num < 1, extreme aspect ratios) against 0-sized resizes.
    scale = max(scale, 1)
    new_w = int(scale * image_size)
    new_h = max(int(new_w / ratio), 1)

    resized = F.interpolate(
        frames, size=(new_h, new_w),
        mode='bicubic',
        align_corners=False
    )
    padded = _pad_height_to_tile(resized)

    if transposed:
        padded = padded.transpose(-2, -1)

    return padded

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) candidate whose cols/rows is nearest ``aspect_ratio``.

    Candidates are scanned in order. On an exact tie in distance, a later
    candidate replaces the current choice only if ``width * height`` exceeds
    half the pixel area of that candidate's grid
    (``0.5 * image_size**2 * cols * rows``). Falls back to ``(1, 1)`` when
    ``target_ratios`` is empty.
    """
    source_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and source_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate
    return chosen


def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2,1)):
    """Resize frames to an exact grid of ``image_size`` tiles, without padding.

    Args:
        frames: float tensor (N, C, H, W).
        image_size: tile side length.
        hd_num: maximum tile count considered by the aspect-ratio search. It is
            only used when ``fix_ratio`` is falsy.
        fix_ratio: (cols, rows) grid to use directly. The default (2, 1) means
            the aspect-ratio search, and therefore ``hd_num``, is skipped unless
            the caller passes a falsy value such as ``None``.

    Returns:
        Tensor (N, C, image_size * rows, image_size * cols). The input aspect
        ratio is not preserved (plain bicubic resize).
    """
    if fix_ratio:
        cols, rows = fix_ratio[0], fix_ratio[1]
    else:
        min_num, max_num = 1, hd_num
        _, _, orig_height, orig_width = frames.shape
        # Every (cols, rows) grid whose tile count lies in [min_num, max_num],
        # ordered by tile count; the scan order matters for tie-breaking.
        target_ratios = sorted(
            {(i, j)
             for n in range(min_num, max_num + 1)
             for i in range(1, n + 1)
             for j in range(1, n + 1)
             if min_num <= i * j <= max_num},
            key=lambda r: r[0] * r[1])
        cols, rows = find_closest_aspect_ratio(
            orig_width / orig_height, target_ratios, orig_width, orig_height, image_size)

    return F.interpolate(
        frames, size=(image_size * rows, image_size * cols),
        mode='bicubic', align_corners=False
    )

# ========================================
#             Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    """Return UI updates that put the demo back to its initial "no video" state.

    The six values map to: chatbot, video input, text box, upload button,
    chat history state, and stored video tensor. A non-None history is
    replaced by a fresh empty list (None stays None), and the stored video
    tensor is always dropped.
    """
    fresh_state = None if chat_state is None else []
    return (
        None,
        gr.update(value=None, interactive=True),
        gr.update(placeholder='Please upload your video first', interactive=False),
        gr.update(value="Upload & Start Chat", interactive=True),
        fresh_state,
        None,
    )

def upload_img( gr_video, num_segments, hd_num, padding):
    """Load the uploaded video into a model-ready tensor and unlock the chat box.

    Returns exactly four values, matching the outputs wired at the upload
    click handler: (video component update, text box update, upload button
    update, video tensor or None).
    """
    if not gr_video:
        # Nothing uploaded: keep the upload controls usable and the chat box
        # locked, since answering needs a video tensor.
        return (
            gr.update(interactive=True),
            gr.update(interactive=False, placeholder='Please upload video/image first!'),
            gr.update(interactive=True),
            None,
        )

    # Slider values are cast to int because get_index feeds num_segments to range().
    video_tensor, _msg = load_video(
        gr_video,
        num_segments=int(num_segments),
        return_msg=True,
        resolution=224,
        hd_num=int(hd_num),
        padding=padding,
    )
    video_tensor = video_tensor.to(model.device)
    return (
        gr.update(interactive=True),
        gr.update(interactive=True, placeholder='Type and press Enter'),
        gr.update(value="Start Chatting", interactive=False),
        video_tensor,
    )

def clear_():
    """Return fresh, empty values for the chatbot transcript and the chat history."""
    empty_transcript = list()
    empty_history = list()
    return empty_transcript, empty_history

def gradio_ask(user_message, chatbot):
    """Append the user's message as a pending turn and clear the text box.

    Returns exactly two values, matching the wired outputs (text box, chatbot).
    Empty or whitespace-only input leaves the transcript unchanged and shows a
    hint in the text box instead.
    """
    if not user_message or not user_message.strip():
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
    # Build a new list (the reply slot stays None until the model answers)
    # rather than mutating the component's value in place.
    chatbot = (chatbot or []) + [[user_message, None]]
    return '', chatbot

@spaces.GPU
def gradio_answer(chatbot, sys_prompt, user_prompt, video_tensor, chat_state, num_beams, temperature, do_sample=False):
    """Generate the model's reply for the newest pending chat turn.

    Args:
        chatbot: transcript as a list of [user, reply] pairs. The last pair's
            reply is filled in place.
        sys_prompt: system prompt forwarded to ``model.chat``.
        user_prompt: text-box value. In this demo the handler is chained after
            gradio_ask, which has already set the box to '', so this is
            normally empty and the pending message in ``chatbot[-1][0]`` is
            used instead.
        video_tensor: tensor produced by upload_img, or None if no video was loaded.
        chat_state: model-side history passed as ``chat_history``.
        num_beams, temperature, do_sample: generation settings.

    Returns:
        (chatbot, chat_state) with the reply and the updated history.
    """
    # Nothing pending: empty transcript, or the last turn already has a reply
    # (gradio_ask rejected empty input, so no new turn was appended).
    if not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state
    if video_tensor is None:
        chatbot[-1][1] = "Please upload a video first."
        return chatbot, chat_state

    prompt = user_prompt or chatbot[-1][0]
    video_tensor = video_tensor.to(model.device)
    response, chat_state = model.chat(tokenizer,
                                      sys_prompt,
                                      prompt,
                                      media_type='video',
                                      media_tensor=video_tensor,
                                      chat_history=chat_state,
                                      return_history=True,
                                      generation_config={
                                          "num_beams": int(num_beams),
                                          "temperature": temperature,
                                          "do_sample": do_sample})
    print(response)
    chatbot[-1][1] = response
    return chatbot, chat_state


class OpenGVLab(gr.themes.base.Base):
    """Gradio theme for the demo page.

    Every keyword is forwarded unchanged to ``gr.themes.base.Base``. The only
    customisation on top is the page background, set to the theme's
    ``*neutral_50`` colour.
    """

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)
        self.set(body_background_fill="*neutral_50")


# Theme instance used by the Blocks UI. Every argument previously spelled out
# here equals the OpenGVLab constructor default, so the defaults are used.
gvlabtheme = OpenGVLab()

# Header HTML (logo linking to the Ask-Anything repo), rendered with gr.Markdown.
title = """<h1 align="center"><a href="https://github.com/OpenGVLab/Ask-Anything"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="Ask-Anything" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
# Subtitle with a GitHub badge, rendered with gr.Markdown below the title.
# NOTE(review): the trailing <p> is never closed.
description ="""
        VideoChat2 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/Ask-Anything'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
        """
# System prompt stored in a gr.State and passed to model.chat; empty by default.
SYS_PROMPT =""

# Demo UI: the left column holds the video upload and generation settings, and
# the right column holds the chat transcript and input box.
with gr.Blocks(title="InternVideo-VideoChat!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")
            sys_prompt = gr.State(f"{SYS_PROMPT}")

            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="Beam search numbers",
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            num_segments = gr.Slider(
                minimum=8,
                maximum=64,
                value=8,
                step=1,
                interactive=True,
                label="Input Frames",
            )

            # Display only: min == max == 224, and this slider is not wired to
            # any handler (upload_img always loads at resolution 224).
            resolution = gr.Slider(
                minimum=224,
                maximum=224,
                value=224,
                step=1,
                interactive=True,
                label="Vision encoder resolution",
            )

            hd_num = gr.Slider(
                minimum=1,
                maximum=10,
                value=4,
                step=1,
                interactive=True,
                label="HD num",
            )

            padding = gr.Checkbox(
                label="padding",
                info=""
            )

        with gr.Column(visible=True)  as input_raws:
            chat_state = gr.State([])
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot",label='VideoChat')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️")

    upload_button.click(upload_img, [ up_video, num_segments, hd_num, padding], [ up_video, text_input, upload_button, img_list])

    # gradio_ask appends the user turn and sets text_input to '' before the
    # chained gradio_answer runs, so the user_prompt gradio_answer receives is
    # normally empty; the pending message is also available as chatbot[-1][0].
    # gradio_ask already clears the box, so no separate clearing handler is needed.
    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature], [chatbot, chat_state]
    )
    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature], [chatbot, chat_state]
    )
    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, [chat_state, img_list], [chatbot,  up_video, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch()