import os
import tempfile
from pdb import set_trace

import torch
from torchvision import transforms

import numpy as np
from skimage.color import rgb2lab, lab2rgb
import skimage.transform
from PIL import Image

from tqdm import tqdm
from moviepy.editor import VideoFileClip, AudioFileClip
from moviepy.tools import cvsecs
import cv2


def lab_to_rgb(L, ab):
    """Convert a batch of normalized Lab tensors into RGB numpy images.

    Args:
        L: 4-D tensor (N, C, H, W) of lightness values normalized to [-1, 1].
            ``(L + 1) * 50`` undoes the ``/ 50 - 1`` scaling applied in
            ``get_L``.
        ab: 4-D tensor of color channels normalized by 110. This is
            presumably (N, 2, H, W), so that L and ab together carry the
            three Lab channels ``lab2rgb`` needs; confirm with the model.

    Returns:
        numpy array (N, H, W, C) holding each image's ``lab2rgb`` result,
        stacked along a new leading axis.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    # Join the channels, move them last (N, H, W, C) and leave torch-land.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)


# Side length, in pixels, of the square image the model is run on: get_L
# resizes every input to (SIZE, SIZE) before extracting its L channel.
SIZE = 256


def get_L(img):
    """Return the normalized L (lightness) channel of ``img`` at SIZE x SIZE.

    The image is resized to ``(SIZE, SIZE)`` with bicubic interpolation and
    converted to CIE Lab with ``rgb2lab``. Its L channel (0..100 in Lab) is
    then rescaled to [-1, 1].

    Args:
        img: a PIL image (or anything ``transforms.Resize`` accepts). PIL
            images in a mode other than RGB (e.g. grayscale ``"L"``,
            ``"RGBA"``, palette ``"P"``) are converted to RGB first, because
            ``rgb2lab`` requires exactly three color channels.

    Returns:
        For a PIL input, a float32 tensor of shape ``(1, SIZE, SIZE)``.
    """
    if isinstance(img, Image.Image) and img.mode != "RGB":
        # Without this, grayscale/palette images give a 2-D array and RGBA
        # gives 4 channels, both of which rgb2lab rejects.
        img = img.convert("RGB")
    img = transforms.Resize(
        (SIZE, SIZE), transforms.InterpolationMode.BICUBIC)(img)
    img_lab = rgb2lab(np.array(img)).astype("float32")
    # ToTensor turns the HWC array into CHW without rescaling float data.
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1; keeps channel dim

    return L


def get_predictions(model, L):
    """Run the colorization model on a batch of L channels.

    Args:
        model: object exposing ``eval()`` and ``forward()``. ``forward()`` is
            called with no arguments, so it presumably reads ``model.L``. It
            must leave its prediction in ``model.fake_color``, which is
            treated as the ab channels by ``lab_to_rgb``.
        L: batch of normalized L channels, e.g. ``get_L`` output with a
            leading batch axis added.

    Returns:
        The numpy RGB batch produced by ``lab_to_rgb``.
    """
    model.eval()
    with torch.no_grad():
        # NOTE(review): the input is always moved to the CPU; model.device
        # is deliberately not used. Confirm this matches where the model's
        # weights live.
        model.L = L.to(torch.device('cpu'))
        model.forward()
    predicted_ab = model.fake_color.detach()
    return lab_to_rgb(L, predicted_ab)


def colorize_img(model, img):
    """Colorize one PIL image and return it at the image's own resolution.

    The model runs on a SIZE x SIZE copy (via ``get_L``). The prediction is
    then resized back to the original size with ``skimage.transform.resize``.

    Args:
        model: colorization model, passed through to ``get_predictions``.
        img: PIL image; ``img.size`` is read as ``(width, height)``.

    Returns:
        numpy array shaped ``(height, width, channels)``.
    """
    batch = get_L(img)[None]  # add a leading batch axis of length 1
    prediction = get_predictions(model, batch)[0]  # drop the batch axis again
    width, height = img.size
    return skimage.transform.resize(prediction, (height, width))


def valid_start_end(duration, start_input, end_input):
    """Parse and validate a time window inside a clip of length ``duration``.

    Args:
        duration: clip length in seconds.
        start_input: window start. ``None`` or an empty/whitespace-only
            string means 0. Any other value is converted with moviepy's
            ``cvsecs``, e.g. a number of seconds or an "HH:MM:SS"-style
            string.
        end_input: window end. ``None`` or an empty/whitespace-only string
            means ``duration``. Any other value is converted like
            ``start_input``.

    Returns:
        ``(start, end)`` in seconds, with start clamped to at least 0 and end
        to at most ``duration``.

    Raises:
        Exception: "Invalid start, end values" if ``cvsecs`` cannot convert
            an input (the original error is chained as ``__cause__``), or
            "Start must be before end." if the clamped start >= end.
    """
    def _is_blank(value):
        # Treat UI-style "nothing entered" values as "use the default".
        return value is None or (isinstance(value, str) and not value.strip())

    start = 0 if _is_blank(start_input) else start_input
    end = duration if _is_blank(end_input) else end_input

    try:
        start = cvsecs(start)
        end = cvsecs(end)
    except Exception as err:
        # start, end aren't actual time values. Catch Exception, not
        # BaseException, so Ctrl-C / SystemExit still propagate.
        raise Exception("Invalid start, end values") from err

    # Clamp to the clip's bounds.
    start = max(start, 0)
    end = min(duration, end)

    # start must be less than end
    if start >= end:
        raise Exception("Start must be before end.")

    return start, end


def _to_bgr_uint8(rgb_frame):
    """Convert a colorized RGB frame into the uint8 BGR image cv2 writes.

    Frames whose maximum is <= 1 are treated as [0, 1] floats and scaled to
    0..255. Values are then clipped and truncated to uint8 (the same
    truncation a plain ``astype`` does), and the channel order is swapped.
    """
    if rgb_frame.max() <= 1:
        rgb_frame = rgb_frame * 255
    # Clip so out-of-range values cannot wrap around, and always end up as
    # uint8: cvtColor/VideoWriter reject float64 frames.
    rgb_frame = np.clip(rgb_frame, 0, 255).astype(np.uint8)
    # The frame is RGB and VideoWriter expects BGR.
    return cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)


def _write_colorized_frames(clip, model, fps, nframes, path_video):
    """Colorize every frame of ``clip`` sampled at ``fps``; encode to ``path_video`` (mp4v)."""
    writer = cv2.VideoWriter(
        path_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, tuple(clip.size))
    if not writer.isOpened():
        # Otherwise nothing is written and the failure only shows up later
        # as a confusing "file not found" when the temp video is reopened.
        raise OSError(f"Could not open video writer for {path_video!r}.")
    try:
        for frame in tqdm(clip.iter_frames(fps=fps), total=nframes):
            color_frame = colorize_img(model, Image.fromarray(frame))
            writer.write(_to_bgr_uint8(color_frame))
    finally:
        writer.release()


def _mux_audio(clip, path_video, path_audio_tmp, path_output):
    """Write ``path_video`` to ``path_output``, adding ``clip``'s audio if it has any."""
    video = VideoFileClip(path_video)
    audio = None
    try:
        output = video
        if clip.audio is not None:
            # The subclip's audio did not survive export when attached
            # directly, so it is round-tripped through a temp audio file.
            clip.audio.write_audiofile(path_audio_tmp, logger=None)
            audio = AudioFileClip(path_audio_tmp)
            output = video.set_audio(audio)
        output.write_videofile(path_output, logger=None)
    finally:
        # Release the ffmpeg readers so the temp files can be deleted
        # (Windows refuses to remove files that are still open).
        if audio is not None:
            audio.close()
        video.close()


def colorize_vid(path_input, model, fps, start_input, end_input):
    """Colorize a window of a video file and write ``<base>_out<suffix>``.

    Frames between ``start_input`` and ``end_input`` are colorized one by one
    with ``colorize_img`` and encoded (mp4v) to a temporary video. That video
    is then re-encoded by moviepy to the output path, with the window's audio
    track added when the source has one. Intermediate files live in a private
    temporary directory beside the input. They are removed even if a step
    fails, and can never overwrite an existing user file.

    Args:
        path_input: path of the source video.
        model: colorization model passed through to ``colorize_img``.
        fps: output frame rate. Only a positive ``int`` is honored; anything
            else (e.g. None, 0, a float) means "use the source frame rate".
        start_input: window start, as accepted by ``valid_start_end``.
        end_input: window end, as accepted by ``valid_start_end``.

    Returns:
        Path of the colorized output video.

    Raises:
        Exception: from ``valid_start_end`` for invalid window bounds.
        OSError: if OpenCV cannot open the temporary video for writing.
    """
    base_path, suffix = os.path.splitext(path_input)
    # create output path that is same as input path but with '_out.[suffix]'
    path_output = base_path + "_out" + suffix
    # Same filesystem as the input, as before; cleanup is handled by
    # TemporaryDirectory.
    tmp_parent = os.path.dirname(os.path.abspath(path_input))

    original_video = VideoFileClip(path_input)
    try:
        start, end = valid_start_end(
            original_video.duration, start_input, end_input)
        input_video = original_video.subclip(start, end)

        if isinstance(fps, int) and fps > 0:
            used_fps = fps
        else:
            used_fps = input_video.fps
        # input_video.reader is the full source file's reader (shared with
        # subclip copies), so reader.nframes would count the whole file.
        # Estimate the subclip's frame count from its duration instead.
        nframes = int(round(used_fps * input_video.duration))
        print(
            f"Colorizing output with FPS: {used_fps}, nframes: {nframes}, "
            f"resolution: {input_video.size}.")

        with tempfile.TemporaryDirectory(dir=tmp_parent) as tmp_dir:
            path_video_tmp = os.path.join(tmp_dir, "video_tmp" + suffix)
            path_audio_tmp = os.path.join(tmp_dir, "audio_tmp.mp3")
            _write_colorized_frames(
                input_video, model, used_fps, nframes, path_video_tmp)
            _mux_audio(
                input_video, path_video_tmp, path_audio_tmp, path_output)
    finally:
        original_video.close()

    print("Done.")
    return path_output