""" File: submit.py Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov Description: Event handler for Gradio app to submit. License: MIT License """ import spaces import torch import pandas as pd import cv2 import gradio as gr # Importing necessary components for the Gradio app from app.config import config_data from app.utils import ( convert_video_to_audio, readetect_speech, slice_audio, find_intersections, calculate_mode, find_nearest_frames, ) from app.plots import ( get_evenly_spaced_frame_indices, plot_audio, display_frame_info, plot_images, plot_predictions, ) from app.data_init import ( read_audio, get_speech_timestamps, vad_model, video_model, asr, audio_model, text_model, ) from app.load_models import VideoFeatureExtractor from app.components import html_message @spaces.GPU def event_handler_submit( video: str, ) -> tuple[gr.HTML, gr.Plot, gr.Plot, gr.Plot, gr.Plot]: audio_file_path = convert_video_to_audio(file_path=video, sr=config_data.General_SR) wav, vad_info = readetect_speech( file_path=audio_file_path, read_audio=read_audio, get_speech_timestamps=get_speech_timestamps, vad_model=vad_model, sr=config_data.General_SR, ) audio_windows = slice_audio( start_time=config_data.General_START_TIME, end_time=int(len(wav)), win_max_length=int(config_data.General_WIN_MAX_LENGTH * config_data.General_SR), win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR), win_min_length=int(config_data.General_WIN_MIN_LENGTH * config_data.General_SR), ) intersections = find_intersections( x=audio_windows, y=vad_info, min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR, ) vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False) vfe.preprocess_video() transcriptions, total_text = asr(wav, audio_windows) window_frames = [] preds_emo = [] preds_sen = [] for w_idx, window in enumerate(audio_windows): a_w = intersections[w_idx] if not a_w["speech"]: a_pred = None else: wave = wav[a_w["start"] : a_w["end"]].clone() a_pred, _ = audio_model(wave) v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH) t_pred, _ = text_model(transcriptions[w_idx][0]) if a_pred: pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3 pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3 else: pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2 pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2 frames = list( range( int(window["start"] * vfe.fps / config_data.General_SR) + 1, int(window["end"] * vfe.fps / config_data.General_SR) + 2, ) ) preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames)) preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames)) window_frames.extend(frames) if max(window_frames) < vfe.frame_number: missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1)) window_frames.extend(missed_frames) preds_emo.extend([preds_emo[-1]] * len(missed_frames)) preds_sen.extend([preds_sen[-1]] * len(missed_frames)) df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"]) df_pred["frames"] = window_frames df_pred["pred_emo"] = preds_emo df_pred["pred_sent"] = preds_sen df_pred = df_pred.groupby("frames").agg( { "pred_emo": calculate_mode, "pred_sent": calculate_mode, } ) frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9) num_frames = len(wav) time_axis = [i / config_data.General_SR for i in range(num_frames)] plt_audio = plot_audio(time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)) all_idx_faces = list(vfe.faces[1].keys()) need_idx_faces = 
    need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)

    faces = []
    for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
        # Resize the face crop and overlay its (1-based) frame number.
        cur_face = cv2.resize(
            vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
        )
        faces.append(
            display_frame_info(
                cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
            )
        )

    plt_faces = plot_images(faces)

    plt_emo = plot_predictions(
        df_pred,
        "pred_emo",
        "Emotion",
        list(config_data.General_DICT_EMO),
        (12, 2.5),
        [i + 1 for i in frame_indices],
        2,
    )
    plt_sent = plot_predictions(
        df_pred,
        "pred_sent",
        "Sentiment",
        list(config_data.General_DICT_SENT),
        (12, 1.5),
        [i + 1 for i in frame_indices],
        2,
    )

    return (
        html_message(
            message=config_data.InformationMessages_NOTI_RESULTS[1],
            error=False,
            visible=False,
        ),
        gr.Plot(value=plt_audio, visible=True),
        gr.Plot(value=plt_faces, visible=True),
        gr.Plot(value=plt_emo, visible=True),
        gr.Plot(value=plt_sent, visible=True),
    )
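

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original app): a minimal
# Gradio Blocks wiring for the handler above. The component names and layout
# below are assumptions; the real application builds its interface elsewhere
# (e.g., in app.components).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with gr.Blocks() as demo:
        video_input = gr.Video(label="Input video")  # passes a file path string
        submit_btn = gr.Button("Submit")
        noti = gr.HTML(visible=False)
        plot_audio_out = gr.Plot(visible=False)
        plot_faces_out = gr.Plot(visible=False)
        plot_emo_out = gr.Plot(visible=False)
        plot_sent_out = gr.Plot(visible=False)
        # The five outputs match the handler's return tuple, in order.
        submit_btn.click(
            fn=event_handler_submit,
            inputs=[video_input],
            outputs=[noti, plot_audio_out, plot_faces_out, plot_emo_out, plot_sent_out],
        )
    demo.launch()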