"""
File: load_models.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Load pretrained models.
License: MIT License
"""

import math
import numpy as np
import cv2

import torch.nn.functional as F
import torch.nn as nn
import torch
from typing import Optional
from PIL import Image
from ultralytics import YOLO
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

from transformers import (
    AutoConfig,
    Wav2Vec2Processor,
    AutoTokenizer,
    AutoModel,
    logging,
)

logging.set_verbosity_error()

from app.utils import pth_processing, get_idx_frames_in_windows

# Importing necessary components for the Gradio app
from app.utils import load_model


class ScaledDotProductAttention_MultiHead(nn.Module):
    """Scaled dot-product attention over inputs that already carry a head axis.

    Computes ``softmax(Q K^T / sqrt(d)) V`` where ``d`` is the last dimension of
    ``key``. Masking is not implemented.
    """

    def __init__(self):
        super(ScaledDotProductAttention_MultiHead, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, num_heads, len_query, dim]
            key:   [batch_size, num_heads, len_keys, dim]
            value: [batch_size, num_heads, len_keys, dim]
            mask: must be None; any other value raises ValueError.

        Returns:
            (output, attention_weights) with shapes
            [batch_size, num_heads, len_query, dim] and
            [batch_size, num_heads, len_query, len_keys].
        """
        # Checked once, up front (a second, unreachable check was removed).
        if mask is not None:
            raise ValueError("Mask is not supported yet")

        emb_dim = key.shape[-1]

        # Attention scores, scaled by sqrt(dim) to keep softmax inputs well-conditioned
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            emb_dim
        )
        attention_weights = self.softmax(attention_weights)

        # Weighted sum of values
        value = torch.matmul(attention_weights, value)
        return value, attention_weights


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied independently at every position.

    ``forward`` computes ``layer_2(relu(dropout(layer_1(x))))``: dropout comes
    before the ReLU, and no residual connection or normalisation happens here.
    """

    def __init__(self, input_dim, hidden_dim, dropout: float = 0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)
        # NOTE(review): never used by forward(); presumably kept so saved
        # state_dicts containing ``layer_norm.*`` keys still load strictly.
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.dropout(self.layer_1(x)))
        return self.layer_2(hidden)


class Add_and_Norm(nn.Module):
    """Residual connection followed by layer normalisation.

    ``forward`` computes ``LayerNorm(dropout(x1) + residual)``. When the layer
    is built with ``dropout=None`` no dropout submodule exists and the dropout
    step is skipped.
    """

    def __init__(self, input_dim, dropout: Optional[float] = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        # The dropout submodule is registered only on request; forward() probes for it.
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x1, residual):
        drop = getattr(self, "dropout", None)
        branch = x1 if drop is None else drop(x1)
        return self.layer_norm(branch + residual)


class MultiHeadAttention(nn.Module):
    """Multi-head attention with bias-free Q/K/V and output projections.

    Inputs are batch-first ``[batch_size, seq_len, input_dim]``. The output is
    ``[batch_size, len_query, input_dim]``, one row per query position.
    """

    def __init__(self, input_dim, num_heads, dropout: Optional[float] = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        if input_dim % num_heads != 0:
            raise ValueError("input_dim must be divisible by num_heads")
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # initialize weights
        self.query_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.keys_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.values_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.ff_layer_after_concat = nn.Linear(
            self.num_heads * self.head_dim, input_dim, bias=False
        )

        self.attention = ScaledDotProductAttention_MultiHead()

        # NOTE(review): this dropout module is never applied in forward().
        if self.dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, mask=None):
        """
        Args:
            queries: [batch_size, len_query, input_dim]
            keys: [batch_size, len_keys, input_dim]
            values: [batch_size, len_values, input_dim]; len_values must equal len_keys
            mask: passed to the attention, which only accepts None.

        Returns:
            Tensor of shape [batch_size, len_query, input_dim].
        """
        batch_size, len_query, len_keys, len_values = (
            queries.size(0),
            queries.size(1),
            keys.size(1),
            values.size(1),
        )

        # project and split into heads: [batch_size, num_heads, seq_len, head_dim]
        queries = (
            self.query_w(queries)
            .view(batch_size, len_query, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        keys = (
            self.keys_w(keys)
            .view(batch_size, len_keys, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        values = (
            self.values_w(values)
            .view(batch_size, len_values, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # attention output: [batch_size, num_heads, len_query, head_dim]
        values, attention_weights = self.attention(queries, keys, values, mask=mask)

        # Concatenate heads. The attention output has one row per *query*
        # position, so reshape with len_query. len_values happens to match it
        # only for self-attention.
        out = (
            values.transpose(1, 2)
            .contiguous()
            .view(batch_size, len_query, self.num_heads * self.head_dim)
        )
        return self.ff_layer_after_concat(out)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first sequences."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: embedding size of the inputs (odd sizes are supported).
            dropout: dropout probability applied after adding the encoding.
            max_len: longest sequence the precomputed table covers.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # ceil(d_model / 2) frequencies
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns. Trim the frequencies so
        # that odd d_model no longer fails with a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # [max_len, 1, d_model] -> [1, max_len, d_model] so it broadcasts over the batch
        pe = pe.permute(1, 0, 2)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim], seq_len <= max_len
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerLayer(nn.Module):
    """Post-norm transformer layer: multi-head attention, then a position-wise
    feed-forward network, each followed by residual add & layer norm.

    When ``positional_encoding`` is true, a sinusoidal encoding is added to key,
    value and query before attention. Argument order of ``forward`` is
    ``(key, value, query)``.
    """

    def __init__(
        self,
        input_dim,
        num_heads,
        dropout: Optional[float] = 0.1,
        positional_encoding: bool = True,
    ):
        super(TransformerLayer, self).__init__()
        self.positional_encoding = positional_encoding
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        self.self_attention = MultiHeadAttention(input_dim, num_heads, dropout=dropout)
        self.feed_forward = PositionWiseFeedForward(input_dim, input_dim, dropout=dropout)
        self.add_norm_after_attention = Add_and_Norm(input_dim, dropout=dropout)
        self.add_norm_after_ff = Add_and_Norm(input_dim, dropout=dropout)

        # When enabled, the boolean flag is replaced by the encoding module itself,
        # so `self.positional_encoding` ends up either False or a PositionalEncoding.
        if positional_encoding:
            self.positional_encoding = PositionalEncoding(input_dim)

    def forward(self, key, value, query, mask=None):
        # inputs: [batch_size, seq_len, input_dim]
        if self.positional_encoding:
            key, value, query = (
                self.positional_encoding(tensor) for tensor in (key, value, query)
            )

        # attention sub-layer; the residual is the (encoded) query
        attended = self.self_attention(queries=query, keys=key, values=value, mask=mask)
        hidden = self.add_norm_after_attention(attended, query)

        # feed-forward sub-layer
        return self.add_norm_after_ff(self.feed_forward(hidden), hidden)


class SelfTransformer(nn.Module):
    """Self-attention block for batch-first inputs [batch, seq, input_size].

    Multi-head self-attention is followed by a linear layer. Each of the two
    stages has a residual connection and layer norm.

    NOTE(review): ``norm2`` is created but never used; the second residual also
    goes through ``norm1``. This looks unintended, but it is kept as-is because
    pretrained weights were presumably produced with this behaviour. ``norm2``
    stays registered so saved state_dicts keep loading.
    """

    def __init__(self, input_size: int = int(1024), num_heads=1, dropout=0.1):
        super(SelfTransformer, self).__init__()
        self.att = torch.nn.MultiheadAttention(
            input_size, num_heads, dropout, bias=True, batch_first=True
        )
        self.norm1 = nn.LayerNorm(input_size)
        self.fcl = nn.Linear(input_size, input_size)
        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, video):
        attended, _ = self.att(video, video, video)
        hidden = self.norm1(video + attended)
        return self.norm1(hidden + self.fcl(hidden))


class SmallClassificationHead(nn.Module):
    """Two parallel linear heads producing emotion and sentiment logits.

    ``forward`` returns ``{"emo": [..., out_emo], "sen": [..., out_sen]}``.
    """

    def __init__(self, input_size=256, out_emo=6, out_sen=3):
        super(SmallClassificationHead, self).__init__()
        self.fc_emo = nn.Linear(input_size, out_emo)
        self.fc_sen = nn.Linear(input_size, out_sen)

    def forward(self, x):
        return dict(emo=self.fc_emo(x), sen=self.fc_sen(x))


class AudioModelWT(Wav2Vec2PreTrainedModel):
    """wav2vec2 encoder + two transformer layers + emotion/sentiment head.

    The 1024-d per-frame wav2vec2 states are refined by ``tl1``/``tl2``,
    reduced to one scalar per frame by ``fc1`` and flattened, so the head sees
    one value per frame.

    NOTE(review): the head is built with ``input_size=199``, so an input must
    yield exactly 199 wav2vec2 frames (presumably the padded 4 s window at
    16 kHz). Confirm against ``config.context_length``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)

        self.f_size = 1024

        # Two identically configured transformer layers on top of wav2vec2.
        self.tl1 = TransformerLayer(
            input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True
        )
        self.tl2 = TransformerLayer(
            input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True
        )

        self.fc1 = nn.Linear(1024, 1)  # one scalar per frame
        self.dp = nn.Dropout(p=0.5)

        self.selu = nn.SELU()
        self.relu = nn.ReLU()
        self.cl_head = SmallClassificationHead(
            input_size=199, out_emo=config.out_emo, out_sen=config.out_sen
        )

        self.init_weights()

        # Keep the convolutional feature encoder frozen.
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """Disable gradients for wav2vec2's convolutional feature encoder."""
        conv_params = self.wav2vec2.feature_extractor.conv_layers.parameters()
        for weight in conv_params:
            weight.requires_grad = False

    def forward(self, x, with_features=False):
        """Return head logits, plus the ``tl2`` output when ``with_features``."""
        hidden = self.wav2vec2(x)[0]
        hidden = self.selu(self.tl1(hidden, hidden, hidden))

        features = self.tl2(hidden, hidden, hidden)

        per_frame = self.dp(self.relu(self.fc1(self.selu(features))))
        logits = self.cl_head(per_frame.view(per_frame.size(0), -1))

        return (logits, features) if with_features else logits


class AudioFeatureExtractor:
    """Audio emotion/sentiment inference with the fine-tuned AudioModelWT."""

    def __init__(
        self,
        checkpoint_url: str,
        folder_path: str,
        device: torch.device,
        sr: int = 16000,
        win_max_length: int = 4,
        with_features: bool = True,
    ) -> None:
        """
        Args:
            checkpoint_url (str): Checkpoint location, resolved via ``load_model``.
            folder_path (str): Folder passed to ``load_model`` together with the URL.
            device (torch.device): Device the model runs on.
            sr (int, optional): Sample rate of audio. Defaults to 16000.
            win_max_length (int, optional): Max length of window in seconds. Defaults to 4.
            with_features (bool, optional): Also return the model's temporal features.
        """
        self.device = device
        self.sr = sr
        self.win_max_length = win_max_length
        self.with_features = with_features

        checkpoint_path = load_model(checkpoint_url, folder_path)

        model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
        model_config = AutoConfig.from_pretrained(model_name)

        # Custom attributes read by AudioModelWT
        model_config.out_emo = 7
        model_config.out_sen = 3
        model_config.context_length = 199

        self.processor = Wav2Vec2Processor.from_pretrained(model_name)

        self.model = AudioModelWT.from_pretrained(
            pretrained_model_name_or_path=model_name, config=model_config
        )

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Explicitly select inference mode (dropout off), like the video loader,
        # instead of relying on from_pretrained having done it.
        self.model.to(self.device).eval()

    def preprocess_wave(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise a waveform with the wav2vec2 processor and pad it to
        ``sr * win_max_length`` samples.

        Padding uses ``padding="max_length"`` without truncation, so longer
        inputs are not cut.

        Args:
            x (torch.Tensor): Input waveform sampled at ``self.sr``.

        Returns:
            torch.Tensor: ``input_values`` of the single processed example.
        """
        a_data = self.processor(
            x,
            sampling_rate=self.sr,
            return_tensors="pt",
            padding="max_length",
            max_length=self.sr * self.win_max_length,
        )
        return a_data["input_values"][0]

    def __call__(
        self, waveform: torch.Tensor
    ) -> tuple[dict[str, torch.Tensor], Optional[torch.Tensor]]:
        """Predict emotion/sentiment probabilities for one waveform window.

        Args:
            waveform (torch.Tensor): Raw waveform window.

        Returns:
            A pair ``(predicts, features)``. ``predicts`` maps ``"emo"``/``"sen"``
            to softmax probabilities on CPU. ``features`` is the model's
            temporal feature tensor on CPU, or ``None`` when ``with_features``
            is False.
        """
        waveform = self.preprocess_wave(waveform).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(waveform, with_features=self.with_features)
            preds, features = output if self.with_features else (output, None)

            predicts = {
                "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(),
                "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(),
            }

        if not self.with_features:
            return predicts, None
        return predicts, features.detach().cpu().squeeze()


class Tmodel(nn.Module):
    """Text emotion/sentiment model over token embeddings [batch, seq, input_size].

    Predictions are computed from the mean of the *input* embeddings over the
    sequence axis, through ``fcl`` and two task branches. The SelfTransformer
    stack only feeds ``self.features`` (see ``get_features``) and does not
    influence ``prob_emo``/``prob_sent``.
    """

    def __init__(
        self,
        input_size: int = int(1024),
        activation=None,
        feature_size1=256,
        feature_size2=64,
        num_heads=1,
        num_layers=2,
        n_emo=7,
        n_sent=3,
    ):
        """
        Args:
            activation: activation module; ``None`` (default) creates a fresh
                ``nn.SELU()`` for this instance. A module-valued default
                argument would be one object shared by every Tmodel.
        """
        super(Tmodel, self).__init__()
        self.feature_text_dynamic = nn.ModuleList(
            [
                SelfTransformer(input_size=input_size, num_heads=num_heads)
                for _ in range(num_layers)
            ]
        )
        self.fcl = nn.Linear(input_size, feature_size1)
        self.activation = nn.SELU() if activation is None else activation
        self.feature_emo = nn.Linear(feature_size1, feature_size2)
        self.feature_sent = nn.Linear(feature_size1, feature_size2)
        self.fc_emo = nn.Linear(feature_size2, n_emo)
        self.fc_sent = nn.Linear(feature_size2, n_sent)

    def get_features(self, t):
        """Run every SelfTransformer layer on ``t`` and store the last output in
        ``self.features``.

        NOTE(review): each layer receives the raw input ``t``, not the previous
        layer's output, so only the last layer's result survives. This is
        presumably baked into the trained checkpoint, so it is left unchanged.
        """
        for layer in self.feature_text_dynamic:
            self.features = layer(t)

    def forward(self, t):
        self.get_features(t)
        represent = self.activation(torch.mean(t, axis=1))
        represent = self.activation(self.fcl(represent))
        represent_emo = self.activation(self.feature_emo(represent))
        represent_sent = self.activation(self.feature_sent(represent))
        prob_emo = self.fc_emo(represent_emo)
        prob_sent = self.fc_sent(represent_sent)
        return prob_emo, prob_sent


class TextFeatureExtractor:
    """Text emotion/sentiment inference: RoBERTa token embeddings -> Tmodel."""

    def __init__(
        self,
        checkpoint_url: str,
        folder_path: str,
        device: torch.device,
        with_features: bool = True,
    ) -> None:
        self.device = device
        self.with_features = with_features

        model_name_bert = "julian-schelb/roberta-ner-multilingual"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_bert, add_prefix_space=True
        )
        # Move the encoder to the target device once, not on every call.
        self.model_bert = AutoModel.from_pretrained(model_name_bert).to(self.device)

        checkpoint_path = load_model(checkpoint_url, folder_path)

        self.model = Tmodel()
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location=self.device)
        )
        # Inference only. torch.no_grad() does not disable dropout, and without
        # eval() the dropout inside Tmodel's SelfTransformer layers makes the
        # returned temporal features random.
        self.model.to(self.device).eval()

    def preprocess_text(self, text: str) -> torch.Tensor:
        """Embed ``text`` with the RoBERTa encoder.

        The lower-cased text is padded/truncated to 6 tokens, and the encoder's
        last hidden state is returned on CPU. Empty strings and NaN-like values
        (``str(text) == "nan"``) yield ``torch.zeros((1, 6, 1024))`` instead.
        """
        if text == "" or str(text) == "nan":
            return torch.zeros((1, 6, 1024))

        inputs = self.tokenizer(
            text.lower(),
            padding="max_length",
            truncation="longest_first",
            return_tensors="pt",
            max_length=6,
        ).to(self.device)
        with torch.no_grad():
            hidden = self.model_bert(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            ).last_hidden_state
        return hidden.cpu().detach()

    def __call__(
        self, text: str
    ) -> tuple[dict[str, torch.Tensor], Optional[torch.Tensor]]:
        """Predict emotion/sentiment probabilities for ``text``.

        Returns:
            ``(predicts, features)``. ``predicts`` maps ``"emo"``/``"sen"`` to
            softmax probabilities on CPU. ``features`` is the last
            SelfTransformer output stored by ``Tmodel.get_features`` (on CPU),
            or ``None`` when ``with_features`` is False.
        """
        text_features = self.preprocess_text(text)

        with torch.no_grad():
            pred_emo, pred_sent = self.model(text_features.float().to(self.device))
            predicts = {
                "emo": F.softmax(pred_emo, dim=-1).detach().cpu().squeeze(),
                "sen": F.softmax(pred_sent, dim=-1).detach().cpu().squeeze(),
            }

        if not self.with_features:
            return predicts, None
        return predicts, self.model.features.detach().cpu().squeeze()


class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm,
    plus an identity shortcut (or one projected by ``i_downsample``).

    The block emits ``out_channels * expansion`` channels. ``stride`` is applied
    by the first 1x1 convolution.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()
        expanded = out_channels * self.expansion

        # NOTE(review): PyTorch's BatchNorm momentum uses the opposite convention
        # to Keras. Confirm 0.99 is intended; it only affects training.
        # 1x1 reduction (carries the stride)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False
        )
        self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        # 3x3 convolution, spatial size preserved
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding="same", bias=False
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        # 1x1 expansion
        self.conv3 = nn.Conv2d(
            out_channels, expanded, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.batch_norm3 = nn.BatchNorm2d(expanded, eps=0.001, momentum=0.99)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        shortcut = x.clone()

        out = self.relu(self.batch_norm1(self.conv1(x)))
        out = self.relu(self.batch_norm2(self.conv2(out)))
        out = self.batch_norm3(self.conv3(out))

        if self.i_downsample is not None:
            shortcut = self.i_downsample(shortcut)

        out += shortcut
        return self.relu(out)


class Conv2dSame(torch.nn.Conv2d):
    """Conv2d with TensorFlow/Keras-style ``'same'`` padding that also works for
    strided convolutions.

    The input is zero-padded on the fly before the convolution. When the total
    padding along a dimension is odd, the extra pixel goes to the bottom/right.
    """

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding for one spatial dim of size ``i`` with kernel ``k``,
        stride ``s`` and dilation ``d`` (never negative)."""
        effective_kernel = (k - 1) * d + 1
        needed = (math.ceil(i / s) - 1) * s + effective_kernel - i
        return needed if needed > 0 else 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad_h, pad_w = (
            self.calc_same_pad(i=size, k=kernel, s=step, d=dil)
            for size, kernel, step, dil in zip(
                x.size()[-2:], self.kernel_size, self.stride, self.dilation
            )
        )

        if pad_h or pad_w:
            top, left = pad_h // 2, pad_w // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        return F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


class ResNet(nn.Module):
    """ResNet backbone with a 'same'-padded 7x7 stem and a two-layer FC head.

    ``extract_features`` returns the 512-d output of ``fc1`` (before its ReLU).
    ``forward`` returns ``num_classes`` logits.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        # Running channel count; _make_layer reads and advances it.
        self.in_channels = 64

        self.conv_layer_s2_same = Conv2dSame(
            num_channels, 64, 7, stride=2, groups=1, bias=False
        )
        self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        # Four stages (layer1..layer4) with (planes, stride) per stage;
        # their depths come from layer_list.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (planes, stride) in enumerate(stage_specs):
            setattr(
                self,
                f"layer{stage_idx + 1}",
                self._make_layer(
                    ResBlock, layer_list[stage_idx], planes=planes, stride=stride
                ),
            )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * ResBlock.expansion, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def extract_features_four(self, x):
        """Return the feature map after the stem and all four stages."""
        x = self.max_pool(self.relu(self.batch_norm1(self.conv_layer_s2_same(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

    def extract_features(self, x):
        """Return the globally pooled features projected by ``fc1``."""
        pooled = self.avgpool(self.extract_features_four(x))
        return self.fc1(pooled.reshape(pooled.shape[0], -1))

    def forward(self, x):
        return self.fc2(self.relu1(self.extract_features(x)))

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks. Only the first block may
        change stride or channel count, via a projection shortcut."""
        out_channels = planes * ResBlock.expansion

        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    padding=0,
                ),
                nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99),
            )

        first = ResBlock(
            self.in_channels, planes, i_downsample=downsample, stride=stride
        )
        self.in_channels = out_channels
        rest = [ResBlock(self.in_channels, planes) for _ in range(blocks - 1)]
        return nn.Sequential(first, *rest)


def ResNet50(num_classes, channels=3):
    """Build a ResNet-50: Bottleneck blocks with stage depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes, num_channels=channels)


class Vmodel(nn.Module):
    """Dynamic video emotion/sentiment model over per-frame features
    [batch, frames, input_size].

    The frames go through a stack of TransformerLayers, are averaged over the
    frame axis, then pass through ``fcl`` and two task branches.
    """

    def __init__(
        self,
        input_size=512,
        activation=None,
        feature_size=64,
        num_heads=1,
        num_layers=1,
        positional_encoding=False,
        n_emo=7,
        n_sent=3,
    ):
        """
        Args:
            activation: activation module; ``None`` (default) creates a fresh
                ``nn.SELU()`` for this instance. A module-valued default
                argument would be one object shared by every Vmodel.
        """
        super(Vmodel, self).__init__()

        self.feature_video_dynamic = nn.ModuleList(
            [
                TransformerLayer(
                    input_dim=input_size,
                    num_heads=num_heads,
                    positional_encoding=positional_encoding,
                )
                for _ in range(num_layers)
            ]
        )

        self.fcl = nn.Linear(input_size, feature_size)
        self.activation = nn.SELU() if activation is None else activation
        self.feature_emo = nn.Linear(feature_size, feature_size)
        self.feature_sent = nn.Linear(feature_size, feature_size)
        self.fc_emo = nn.Linear(feature_size, n_emo)
        self.fc_sent = nn.Linear(feature_size, n_sent)

    def forward(self, x, with_features=False):
        """Return ``{"emo", "sen"}`` logits, plus the transformer output when
        ``with_features``."""
        for layer in self.feature_video_dynamic:
            x = layer(x, x, x)

        represent = self.activation(torch.mean(x, axis=1))
        represent = self.activation(self.fcl(represent))
        represent_emo = self.activation(self.feature_emo(represent))
        represent_sent = self.activation(self.feature_sent(represent))
        prob_emo = self.fc_emo(represent_emo)
        prob_sent = self.fc_sent(represent_sent)

        if with_features:
            return {"emo": prob_emo, "sen": prob_sent}, x
        return {"emo": prob_emo, "sen": prob_sent}


class VideoModelLoader:
    """Load the video pipeline models: YOLO face detector, static EmoAffectNet
    (ResNet-50) feature model, and dynamic emotion/sentiment model (Vmodel)."""

    def __init__(
        self,
        face_checkpoint_url: str,
        emotion_checkpoint_url: str,
        emo_sent_checkpoint_url: str,
        folder_path: str,
        device: torch.device,
    ) -> None:
        self.device = device

        # Resolve checkpoint files via load_model
        face_model_path = load_model(face_checkpoint_url, folder_path)
        emotion_video_model_path = load_model(emotion_checkpoint_url, folder_path)
        emo_sent_video_model_path = load_model(emo_sent_checkpoint_url, folder_path)

        # YOLO face detection/tracking model
        self.face_model = YOLO(face_model_path)

        # EmoAffectNet model initialization (static model)
        self.emo_affectnet_model = ResNet50(num_classes=7, channels=3)
        self.emo_affectnet_model.load_state_dict(
            torch.load(emotion_video_model_path, map_location=self.device)
        )
        self.emo_affectnet_model.to(self.device).eval()

        # Visual emotion and sentiment recognition model (dynamic model)
        self.emo_sent_video_model = Vmodel()
        self.emo_sent_video_model.load_state_dict(
            torch.load(emo_sent_video_model_path, map_location=self.device)
        )
        self.emo_sent_video_model.to(self.device).eval()

    def extract_zeros_features(self):
        """Return the EmoAffectNet ``extract_features`` output for an all-zero
        (1, 3, 224, 224) input, as a 1-D numpy array."""
        zeros = torch.zeros((1, 3, 224, 224), device=self.device)
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            zeros_features = self.emo_affectnet_model.extract_features(zeros)
        return zeros_features.cpu().detach().numpy()[0]


class VideoFeatureExtractor:
    """Extract per-frame facial features from a video and run the dynamic video
    model on windows of them.

    Call ``preprocess_video`` once to fill ``facial_features`` (sampled frame
    index -> averaged EmoAffectNet feature vector) and ``faces`` (track id, or
    -1 when untracked -> {frame index: RGB face crop}). Then call the instance
    once per window.
    """

    def __init__(
        self,
        model_loader: VideoModelLoader,
        file_path: str,
        target_fps: int = 5,
        with_features: bool = True,
    ) -> None:
        """
        Args:
            model_loader: provides the face model, the EmoAffectNet model, the
                dynamic video model and the device.
            file_path: video file opened with ``cv2.VideoCapture``.
            target_fps: approximate number of frames per second to analyse.
            with_features: whether ``__call__`` also returns temporal features.
        """
        self.model_loader = model_loader
        self.with_features = with_features

        # Video options
        self.cap = cv2.VideoCapture(file_path)
        self.w, self.h, self.fps, self.frame_number = (
            int(self.cap.get(x))
            for x in (
                cv2.CAP_PROP_FRAME_WIDTH,
                cv2.CAP_PROP_FRAME_HEIGHT,
                cv2.CAP_PROP_FPS,
                cv2.CAP_PROP_FRAME_COUNT,
            )
        )
        self.dur = self.frame_number / self.fps
        self.target_fps = target_fps
        # Analyse every `frame_interval`-th frame. The value is clamped to 1:
        # when fps < target_fps the plain quotient is 0, and
        # `counter % self.frame_interval` would raise ZeroDivisionError.
        self.frame_interval = max(1, int(self.fps / target_fps))

        # Extract zero features if no face found in frame
        self.zeros_features = self.model_loader.extract_zeros_features()

        # Dictionaries with facial features and faces
        self.facial_features = {}
        self.faces = {}

    def _fallback_features(self) -> np.ndarray:
        """Features for a sampled frame without usable faces.

        Reuses the most recently stored frame's features, or the zero-input
        features if no frame has been stored yet.
        """
        if self.facial_features:
            # Dicts keep insertion order. When frames are fed in increasing order
            # (as preprocess_video does), the last value is the previous *sampled*
            # frame. The former `counter - 1` lookup only matched that when
            # frame_interval == 1.
            return next(reversed(self.facial_features.values()))
        return self.zeros_features

    def preprocess_frame(self, frame: np.ndarray, counter: int) -> None:
        """Detect and track faces in one BGR frame, then store the mean of their
        features under ``counter`` and their crops in ``self.faces``."""
        curr_fr = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.model_loader.face_model.track(
            curr_fr,
            persist=True,
            imgsz=640,
            conf=0.01,
            iou=0.5,
            augment=False,
            device=self.model_loader.device,
            verbose=False,
        )

        need_features = np.zeros(512)
        count_face = 0

        for box_result in results[0].boxes:
            # `id` is None when the tracker assigned no id. Compare to None
            # explicitly so a (tensor) id of 0 is not mistaken for "no id".
            idx_box = (
                box_result.id.int().cpu().tolist()[0]
                if box_result.id is not None
                else -1
            )
            box = box_result.xyxy.int().cpu().tolist()[0]
            startX, startY = max(0, box[0]), max(0, box[1])
            endX, endY = min(self.w - 1, box[2]), min(self.h - 1, box[3])

            face_region = curr_fr[startY:endY, startX:endX]
            if face_region.size == 0:
                # Degenerate box after clipping: nothing to embed.
                continue

            # pth_processing: project helper that prepares the crop for EmoAffectNet
            norm_face_region = pth_processing(Image.fromarray(face_region))
            with torch.no_grad():
                curr_features = self.model_loader.emo_affectnet_model.extract_features(
                    norm_face_region.to(self.model_loader.device)
                )
            need_features += curr_features.cpu().detach().numpy()[0]
            count_face += 1

            self.faces.setdefault(idx_box, {})[counter] = face_region

        if count_face:
            self.facial_features[counter] = need_features / count_face
        else:
            self.facial_features[counter] = self._fallback_features()

    def preprocess_video(self) -> None:
        """Read the whole video and process every ``frame_interval``-th frame.

        NOTE(review): ``self.cap`` is left open (and exhausted) afterwards.
        """
        counter = 0

        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            if counter % self.frame_interval == 0:
                self.preprocess_frame(frame, counter)
            counter += 1

    def __call__(
        self, window: dict, win_max_length: int, sr: int = 16000
    ) -> tuple[dict[str, torch.Tensor], Optional[torch.Tensor]]:
        """Run the dynamic video model on the frames that fall into ``window``.

        The selected frames are padded to ``target_fps * win_max_length`` rows
        by repeating the last one, or the zero-input features when the window
        contains no sampled frame.

        Returns:
            ``(predicts, features)``. ``predicts`` maps ``"emo"``/``"sen"`` to
            softmax probabilities on CPU. ``features`` is the temporal features
            on CPU, or ``None`` when ``with_features`` is False.
        """
        # get_idx_frames_in_windows: project helper, presumably returning
        # positions of the sampled frames inside `window` — verify.
        curr_idx_frames = get_idx_frames_in_windows(
            list(self.facial_features.keys()), window, self.fps, sr
        )

        video_features = np.array(list(self.facial_features.values()))

        curr_features = video_features[curr_idx_frames, :]

        target_len = self.target_fps * win_max_length
        if len(curr_features) < target_len:
            # Previously an empty window crashed on curr_features[-1].
            pad_row = (
                curr_features[-1] if len(curr_features) else self.zeros_features
            )
            diff = target_len - len(curr_features)
            curr_features = np.concatenate([curr_features, [pad_row] * diff], axis=0)

        curr_features = (
            torch.FloatTensor(curr_features).unsqueeze(0).to(self.model_loader.device)
        )

        with torch.no_grad():
            output = self.model_loader.emo_sent_video_model(
                curr_features, with_features=self.with_features
            )
            preds, features = output if self.with_features else (output, None)

            predicts = {
                "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(),
                "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(),
            }

        if not self.with_features:
            return predicts, None
        return predicts, features.detach().cpu().squeeze()