from typing import List

import cv2
import torch
import numpy as np
from tqdm import tqdm
import supervision as sv
import torch.nn.functional as F
from transformers import AutoModel
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler


def load_video_frames(video_path: str) -> List[np.ndarray]:
    """Decode every frame of a video file into memory, in RGB channel order.

    Frames come from ``sv.get_video_frames_generator`` and are converted with
    ``cv2.COLOR_BGR2RGB``, i.e. the generator's output is treated as BGR.
    A tqdm progress bar counts frames as they are decoded.

    Args:
        video_path: Path of the video to read.

    Returns:
        A list holding one RGB image array per frame, in playback order.
    """
    frame_source = sv.get_video_frames_generator(source_path=video_path)
    return [
        cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        for bgr_frame in tqdm(frame_source, unit=" frames")
    ]

def preprocess(image: np.ndarray, n_patches: int, device: str, patch_size: int = 14) -> torch.Tensor:
    """Turn one image into a normalised, batched model input tensor.

    The image is permuted with ``permute(2, 0, 1)``, so it is expected as
    (H, W, C). It is divided by 255, so values are presumably in 0-255
    (e.g. uint8 RGB from ``load_video_frames``; confirm for other callers).
    It is then resized to a square of ``n_patches * patch_size`` pixels and
    normalised with the ImageNet mean/std.

    Args:
        image: Input image, (H, W, C).
        n_patches: Number of patches along each side of the resized image.
        device: Device the resulting tensor is moved to.
        patch_size: Pixel size of one patch.

    Returns:
        Float tensor of shape (1, C, n_patches * patch_size, n_patches * patch_size).
    """
    side = n_patches * patch_size
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    pipeline = T.Compose([
        T.Resize((side, side)),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # HWC -> CHW float, then scale from 0-255 to 0-1.
    chw = torch.from_numpy(image).type(torch.float).permute(2, 0, 1)
    chw = chw / 255

    batched = pipeline(chw).unsqueeze(0)
    return batched.to(device)


def process_video(
    model: AutoModel,
    video: str | List[np.ndarray],
    is_larger: bool = True,
    batch_size: int = 4,
    threshold: float = 0.5,
    n_patches: int = 40,
    interpolate: bool = False,
    device: str = "cpu"
) -> List[np.ndarray]:
    """Colourise video frames by a PCA of their transformer patch features.

    Frames are run through ``model`` in batches. Within each batch, the patch
    features are projected with a 3-component PCA. The first component is
    min-max scaled to [0, 1] and thresholded, splitting patches into
    background and foreground. Background patches are painted 0 (black).
    A second PCA, fitted on the foreground patches alone and min-max scaled,
    gives each foreground patch its three colour channels.

    Both PCAs are refitted for every batch, so colours are not guaranteed to
    be consistent between batches.

    Args:
        model: Model exposing ``config.patch_size`` and returning
            ``last_hidden_state``. Presumably already on ``device``; TODO
            confirm with callers.
        video: Path to a video file, or a list of already-loaded (H, W, C)
            frames. All frames are assumed to share the first frame's size.
        is_larger: If True, patches whose scaled first component is above
            ``threshold`` are background. If False, those below it are.
        batch_size: Frames per forward pass (and per PCA fit). Must be >= 1.
        threshold: Cut-off applied to the scaled first PCA component.
        n_patches: Side of the patch grid. Frames are resized to
            ``n_patches * patch_size`` pixels square.
        interpolate: If True, upsample each patch map bilinearly to the
            original frame size. Otherwise return (n_patches, n_patches, 3) maps.
        device: Device the preprocessed pixel tensors are moved to.

    Returns:
        One float ``np.ndarray`` per input frame with values in [0, 1], of
        shape (H, W, 3) if ``interpolate`` else (n_patches, n_patches, 3).
        Returns an empty list for an empty video.

    Raises:
        ValueError: If ``batch_size`` < 1, or the model returns fewer tokens
            than ``n_patches ** 2``.
    """
    # Shape key: B = frames in batch, NP = n_patches, P = patch_size.
    frames = load_video_frames(video) if isinstance(video, str) else video
    if len(frames) == 0:
        return []
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    patch_size = model.config.patch_size
    n_tokens = n_patches * n_patches
    # Frames are (H, W, C), so the first two axes are height and width.
    original_height, original_width = frames[0].shape[:2]

    pca = PCA(n_components=3)
    scaler = MinMaxScaler(clip=True)
    final_frames = []

    # Step by batch_size so a trailing partial batch is processed too.
    for start in range(0, len(frames), batch_size):
        batch = frames[start:start + batch_size]
        n_frames = len(batch)  # may be < batch_size for the last batch

        pixel_values = torch.cat(
            [preprocess(f, n_patches, device, patch_size) for f in batch]
        )  # B, C, NP * P, NP * P

        with torch.no_grad():
            hidden = model(pixel_values=pixel_values).last_hidden_state

        if hidden.shape[1] < n_tokens:
            raise ValueError(
                f"model returned {hidden.shape[1]} tokens, expected at least "
                f"{n_tokens} patch tokens for n_patches={n_patches}"
            )
        # Keep the last NP * NP tokens as the patch tokens. With a single
        # leading CLS token this equals ``[:, 1:]``. It also skips any other
        # leading tokens, assuming patch tokens come last (TODO confirm per
        # model). ``.float()`` because numpy has no bfloat16 dtype.
        features = hidden[:, -n_tokens:].float().cpu().numpy()
        features = features.reshape(n_frames * n_tokens, -1)  # B * NP * NP, HIDDEN_DIM

        # The scaled first component separates background from foreground.
        first_component = scaler.fit_transform(pca.fit_transform(features))[:, 0]
        if is_larger:
            background = first_component > threshold
        else:
            background = first_component < threshold
        foreground = ~background

        # Background stays 0 (black). The second PCA needs at least
        # n_components samples, so with fewer foreground patches the whole
        # batch is left black instead of raising.
        rgb = np.zeros((n_frames * n_tokens, 3))
        if np.count_nonzero(foreground) >= pca.n_components:
            rgb[foreground] = scaler.fit_transform(pca.fit_transform(features[foreground]))
        rgb = rgb.reshape(n_frames, n_patches, n_patches, 3)  # B, NP, NP, 3

        if interpolate:
            upsampled = F.interpolate(
                torch.from_numpy(rgb).permute(0, 3, 1, 2),  # B, 3, NP, NP
                size=(original_height, original_width),
                mode="bilinear",
                align_corners=False,
            ).permute(0, 2, 3, 1)  # B, H, W, 3
            # Convert back so both branches yield np.ndarray frames.
            rgb = upsampled.numpy()

        # Iterating over the leading axis yields one (…, …, 3) array per frame.
        final_frames.extend(rgb)

    return final_frames


def create_video_from_frames_rgb(
    frame_list: List[np.ndarray],
    output_filename: str = "animation.mp4",
    fps: int = 15
) -> str:
    """Write a sequence of RGB float frames to a video file.

    Args:
        frame_list: Non-empty sequence of (H, W, 3) frames in RGB channel
            order with values in [0, 1], all the size of the first frame.
            Anything ``np.asarray`` accepts (e.g. CPU torch tensors) works.
        output_filename: Destination path of the video.
        fps: Frame rate of the written video.

    Returns:
        ``output_filename``.

    Raises:
        ValueError: If ``frame_list`` is empty.
        IOError: If OpenCV cannot open a writer for ``output_filename``.
    """
    if len(frame_list) == 0:
        raise ValueError("frame_list must contain at least one frame")

    # Video dimensions come from the first frame.
    frame_height, frame_width, _ = np.asarray(frame_list[0]).shape

    # 'mp4v' suits the default .mp4 output; change it for other containers.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_filename, fourcc, fps, (frame_width, frame_height))
    if not writer.isOpened():
        raise IOError(f"Could not open video writer for {output_filename!r}")

    try:
        for frame in frame_list:
            # Clip before casting so out-of-range values saturate rather
            # than wrapping around in uint8.
            frame_u8 = np.clip(np.asarray(frame) * 255, 0, 255).astype(np.uint8)
            # OpenCV's VideoWriter expects BGR channel order.
            writer.write(cv2.cvtColor(frame_u8, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the file, even if a frame fails to convert.
        writer.release()

    return output_filename