DmitryRyumin committed
Commit 15b7f31 · 1 Parent(s): 5616336
app/data_init.py CHANGED
@@ -21,6 +21,7 @@ vad_model, vad_utils = torch.hub.load(
     repo_or_dir=config_data.StaticPaths_VAD_MODEL,
     model="silero_vad",
     force_reload=False,
+    verbose=False,
     onnx=False,
 )
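Note: `verbose` is a standard `torch.hub.load` keyword; setting it to `False` mutes the "Using cache found in ..." messages when the Silero VAD model is loaded. A minimal sketch of the flag in isolation, with the public hub entry used as a stand-in for the local path kept in `config_data.StaticPaths_VAD_MODEL`:

    import torch

    # Stand-in for config_data.StaticPaths_VAD_MODEL; the app points at its own cached copy
    vad_model, vad_utils = torch.hub.load(
        "snakers4/silero-vad", "silero_vad", force_reload=False, verbose=False, onnx=False
    )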
app/event_handlers/clear.py CHANGED
@@ -11,11 +11,19 @@ import gradio as gr
 from app.config import config_data
 
 
-def event_handler_clear() -> (
-    tuple[
-        gr.Video, gr.Button, gr.Button, gr.Textbox, gr.Plot, gr.Plot, gr.Plot, gr.Plot
-    ]
-):
+def event_handler_clear() -> tuple[
+    gr.Video,
+    gr.Button,
+    gr.Button,
+    gr.Textbox,
+    gr.Plot,
+    gr.Plot,
+    gr.Plot,
+    gr.Plot,
+    gr.Row,
+    gr.Textbox,
+    gr.Textbox,
+]:
     return (
         gr.Video(value=None),
         gr.Button(interactive=False),
@@ -30,4 +38,7 @@ def event_handler_clear() -> (
         gr.Plot(value=None, visible=False),
         gr.Plot(value=None, visible=False),
         gr.Plot(value=None, visible=False),
+        gr.Row(visible=False),
+        gr.Textbox(value=None, visible=False),
+        gr.Textbox(value=None, visible=False),
     )
app/event_handlers/event_handlers.py CHANGED
@@ -22,6 +22,9 @@ def setup_app_event_handlers(
     faces,
     emotion_stats,
     sent_stats,
+    time_row,
+    video_duration,
+    calculate_time,
 ):
     gr.on(
         triggers=[video.change, video.upload, video.stop_recording, video.clear],
@@ -40,6 +43,9 @@ def setup_app_event_handlers(
             faces,
             emotion_stats,
             sent_stats,
+            time_row,
+            video_duration,
+            calculate_time,
         ],
         queue=True,
     )
@@ -56,6 +62,9 @@ def setup_app_event_handlers(
             faces,
             emotion_stats,
             sent_stats,
+            time_row,
+            video_duration,
+            calculate_time,
         ],
         queue=True,
     )
app/event_handlers/submit.py CHANGED
@@ -14,6 +14,7 @@ import gradio as gr
 # Importing necessary components for the Gradio app
 from app.config import config_data
 from app.utils import (
+    Timer,
     convert_video_to_audio,
     readetect_speech,
     slice_audio,
@@ -44,126 +45,144 @@ from app.load_models import VideoFeatureExtractor
 @spaces.GPU
 def event_handler_submit(
     video: str,
-) -> tuple[gr.Textbox, gr.Plot, gr.Plot, gr.Plot, gr.Plot]:
-    if video:
-        if video.split(".")[-1] == "webm":
-            video = convert_webm_to_mp4(video)
-
-        audio_file_path = convert_video_to_audio(file_path=video, sr=config_data.General_SR)
-        wav, vad_info = readetect_speech(
-            file_path=audio_file_path,
-            read_audio=read_audio,
-            get_speech_timestamps=get_speech_timestamps,
-            vad_model=vad_model,
-            sr=config_data.General_SR,
-        )
-
-        audio_windows = slice_audio(
-            start_time=config_data.General_START_TIME,
-            end_time=int(len(wav)),
-            win_max_length=int(config_data.General_WIN_MAX_LENGTH * config_data.General_SR),
-            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
-            win_min_length=int(config_data.General_WIN_MIN_LENGTH * config_data.General_SR),
-        )
-
-        intersections = find_intersections(
-            x=audio_windows,
-            y=vad_info,
-            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
-        )
-
-        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
-        vfe.preprocess_video()
-
-        transcriptions, total_text = asr(wav, audio_windows)
-
-        window_frames = []
-        preds_emo = []
-        preds_sen = []
-        for w_idx, window in enumerate(audio_windows):
-            a_w = intersections[w_idx]
-            if not a_w["speech"]:
-                a_pred = None
-            else:
-                wave = wav[a_w["start"] : a_w["end"]].clone()
-                a_pred, _ = audio_model(wave)
-
-            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
-
-            t_pred, _ = text_model(transcriptions[w_idx][0])
-
-            if a_pred:
-                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
-                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
-            else:
-                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
-                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
-
-            frames = list(
-                range(
-                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
-                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
-                )
-            )
-            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
-            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
-            window_frames.extend(frames)
-
-        if max(window_frames) < vfe.frame_number:
-            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
-            window_frames.extend(missed_frames)
-            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
-            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
-
-        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
-        df_pred["frames"] = window_frames
-        df_pred["pred_emo"] = preds_emo
-        df_pred["pred_sent"] = preds_sen
-
-        df_pred = df_pred.groupby("frames").agg(
-            {
-                "pred_emo": calculate_mode,
-                "pred_sent": calculate_mode,
-            }
-        )
-
-        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
-        num_frames = len(wav)
-        time_axis = [i / config_data.General_SR for i in range(num_frames)]
-        plt_audio = plot_audio(time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2))
-
-        all_idx_faces = list(vfe.faces[1].keys())
-        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
-        faces = []
-        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
-            cur_face = cv2.resize(
-                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
-            )
-            faces.append(
-                display_frame_info(
-                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
-                )
-            )
-        plt_faces = plot_images(faces)
-
-        plt_emo = plot_predictions(
-            df_pred,
-            "pred_emo",
-            "Emotion",
-            list(config_data.General_DICT_EMO),
-            (12, 2.5),
-            [i + 1 for i in frame_indices],
-            2,
-        )
-        plt_sent = plot_predictions(
-            df_pred,
-            "pred_sent",
-            "Sentiment",
-            list(config_data.General_DICT_SENT),
-            (12, 1.5),
-            [i + 1 for i in frame_indices],
-            2,
-        )
+) -> tuple[
+    gr.Textbox,
+    gr.Plot,
+    gr.Plot,
+    gr.Plot,
+    gr.Plot,
+    gr.Row,
+    gr.Textbox,
+    gr.Textbox,
+]:
+    with Timer() as t:
+        if video:
+            if video.split(".")[-1] == "webm":
+                video = convert_webm_to_mp4(video)
+
+            audio_file_path = convert_video_to_audio(
+                file_path=video, sr=config_data.General_SR
+            )
+            wav, vad_info = readetect_speech(
+                file_path=audio_file_path,
+                read_audio=read_audio,
+                get_speech_timestamps=get_speech_timestamps,
+                vad_model=vad_model,
+                sr=config_data.General_SR,
+            )
+
+            audio_windows = slice_audio(
+                start_time=config_data.General_START_TIME,
+                end_time=int(len(wav)),
+                win_max_length=int(
+                    config_data.General_WIN_MAX_LENGTH * config_data.General_SR
+                ),
+                win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
+                win_min_length=int(
+                    config_data.General_WIN_MIN_LENGTH * config_data.General_SR
+                ),
+            )
+
+            intersections = find_intersections(
+                x=audio_windows,
+                y=vad_info,
+                min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
+            )
+
+            vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
+            vfe.preprocess_video()
+
+            transcriptions, total_text = asr(wav, audio_windows)
+
+            window_frames = []
+            preds_emo = []
+            preds_sen = []
+            for w_idx, window in enumerate(audio_windows):
+                a_w = intersections[w_idx]
+                if not a_w["speech"]:
+                    a_pred = None
+                else:
+                    wave = wav[a_w["start"] : a_w["end"]].clone()
+                    a_pred, _ = audio_model(wave)
+
+                v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
+
+                t_pred, _ = text_model(transcriptions[w_idx][0])
+
+                if a_pred:
+                    pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
+                    pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
+                else:
+                    pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
+                    pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
+
+                frames = list(
+                    range(
+                        int(window["start"] * vfe.fps / config_data.General_SR) + 1,
+                        int(window["end"] * vfe.fps / config_data.General_SR) + 2,
+                    )
+                )
+                preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
+                preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
+                window_frames.extend(frames)
+
+            if max(window_frames) < vfe.frame_number:
+                missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
+                window_frames.extend(missed_frames)
+                preds_emo.extend([preds_emo[-1]] * len(missed_frames))
+                preds_sen.extend([preds_sen[-1]] * len(missed_frames))
+
+            df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
+            df_pred["frames"] = window_frames
+            df_pred["pred_emo"] = preds_emo
+            df_pred["pred_sent"] = preds_sen
+
+            df_pred = df_pred.groupby("frames").agg(
+                {
+                    "pred_emo": calculate_mode,
+                    "pred_sent": calculate_mode,
+                }
+            )
+
+            frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
+            num_frames = len(wav)
+            time_axis = [i / config_data.General_SR for i in range(num_frames)]
+            plt_audio = plot_audio(
+                time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
+            )
+
+            all_idx_faces = list(vfe.faces[1].keys())
+            need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
+            faces = []
+            for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
+                cur_face = cv2.resize(
+                    vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
+                )
+                faces.append(
+                    display_frame_info(
+                        cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
+                    )
+                )
+            plt_faces = plot_images(faces)
+
+            plt_emo = plot_predictions(
+                df_pred,
+                "pred_emo",
+                "Emotion",
+                list(config_data.General_DICT_EMO),
+                (12, 2.5),
+                [i + 1 for i in frame_indices],
+                2,
+            )
+            plt_sent = plot_predictions(
+                df_pred,
+                "pred_sent",
+                "Sentiment",
+                list(config_data.General_DICT_SENT),
+                (12, 1.5),
+                [i + 1 for i in frame_indices],
+                2,
+            )
 
     return (
         gr.Textbox(
@@ -176,4 +195,10 @@ def event_handler_submit(
         gr.Plot(value=plt_faces, visible=True),
         gr.Plot(value=plt_emo, visible=True),
         gr.Plot(value=plt_sent, visible=True),
+        gr.Row(visible=True),
+        gr.Textbox(
+            value=config_data.InformationMessages_VIDEO_DURATION.format(vfe.dur),
+            visible=True,
+        ),
+        gr.Textbox(value=t, visible=True),
     )
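Note: the whole pipeline now runs inside `with Timer() as t:`, and the timer object itself is handed to the output component; since `Timer.__str__` returns the message built in `__exit__`, `gr.Textbox(value=t, ...)` is presumably rendered as that string. A minimal usage sketch of the same pattern outside the app (`heavy_call` is a hypothetical stand-in for the inference code):

    from app.utils import Timer  # added to app/utils.py in this commit

    with Timer() as t:
        heavy_call()  # hypothetical placeholder for the submit pipeline

    print(t)  # e.g. "Inference time: 12.34 seconds"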
app/load_models.py CHANGED
@@ -20,7 +20,15 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2PreTrainedModel,
 )
 
-from transformers import AutoConfig, Wav2Vec2Processor, AutoTokenizer, AutoModel
+from transformers import (
+    AutoConfig,
+    Wav2Vec2Processor,
+    AutoTokenizer,
+    AutoModel,
+    logging,
+)
+
+logging.set_verbosity_error()
 
 from app.utils import pth_processing, get_idx_frames_in_windows
 
@@ -838,9 +846,6 @@ class VideoFeatureExtractor:
             need_features += curr_features.cpu().detach().numpy()[0]
             count_face += 1
 
-            # face_region = cv2.resize(face_region, (224,224), interpolation = cv2.INTER_AREA)
-            # face_region = display_frame_info(face_region, 'Frame: {}'.format(count_face), box_scale=.3)
-
            if idx_box in self.faces:
                 self.faces[idx_box].update({counter: face_region})
             else:
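Note: `logging` here is the transformers logging facade rather than the standard-library module, and `set_verbosity_error()` silences the info/warning chatter printed while checkpoints load. A minimal sketch; the environment variable is an equivalent, documented alternative:

    from transformers import logging as hf_logging

    hf_logging.set_verbosity_error()  # only errors from transformers are shown
    # Equivalent knob, set before the process starts: TRANSFORMERS_VERBOSITY=error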
app/tabs.py CHANGED
@@ -33,7 +33,7 @@ def app_tab():
         show_label=True,
         interactive=True,
         visible=True,
-        mirror_webcam=True,
+        mirror_webcam=False,
         include_audio=True,
         elem_classes="video",
         autoplay=False,
@@ -123,6 +123,50 @@ def app_tab():
         elem_classes="sent-stats",
     )
 
+    with gr.Row(
+        visible=False,
+        render=True,
+        variant="default",
+        elem_classes="time-container",
+    ) as time_row:
+        video_duration = gr.Textbox(
+            value=None,
+            max_lines=1,
+            placeholder=None,
+            label=None,
+            info=None,
+            show_label=False,
+            container=False,
+            interactive=False,
+            visible=False,
+            autofocus=False,
+            autoscroll=True,
+            render=True,
+            type="text",
+            show_copy_button=False,
+            max_length=50,
+            elem_classes="video_duration",
+        )
+
+        calculate_time = gr.Textbox(
+            value=None,
+            max_lines=1,
+            placeholder=None,
+            label=None,
+            info=None,
+            show_label=False,
+            container=False,
+            interactive=False,
+            visible=False,
+            autofocus=False,
+            autoscroll=True,
+            render=True,
+            type="text",
+            show_copy_button=False,
+            max_length=50,
+            elem_classes="calculate_time",
+        )
+
     return (
         video,
         clear,
@@ -132,6 +176,9 @@ def app_tab():
         faces,
         emotion_stats,
         sent_stats,
+        time_row,
+        video_duration,
+        calculate_time,
     )
 
 
app/utils.py CHANGED
@@ -5,6 +5,7 @@ Description: Utility functions.
 License: MIT License
 """
 
+import time
 import torch
 import os
 import subprocess
@@ -17,10 +18,26 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from pathlib import Path
 from contextlib import suppress
 from urllib.parse import urlparse
+from contextlib import ContextDecorator
 
 from typing import Callable
 
 
+class Timer(ContextDecorator):
+    """Context manager for measuring code execution time"""
+
+    def __enter__(self):
+        self.start = time.time()
+        return self
+
+    def __exit__(self, *args):
+        self.end = time.time()
+        self.execution_time = f"Inference time: {self.end - self.start:.2f} seconds"
+
+    def __str__(self):
+        return self.execution_time
+
+
 def load_model(
     model_url: str, folder_path: str, force_reload: bool = False
 ) -> str | None:
@@ -259,7 +276,11 @@ class ASRModel:
             transcription = self.processor.batch_decode(
                 predicted_ids, skip_special_tokens=False
            )
-            texts.append(re.findall(r"> ([^<>]+)", transcription[0]))
+            curr_text = re.findall(r"> ([^<>]+)", transcription[0])
+            if curr_text:
+                texts.append(curr_text)
+            else:
+                texts.append("")
 
             # for drawing
             input_features = self.processor(
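Note: the new guard matters because `re.findall` returns an empty list when a decoded window contains no text between Whisper's special tokens, and appending `""` keeps `texts` aligned with the audio windows (the original hunk also misspelled the call as `texts.appemd`, corrected above). A small illustration of the pattern on an assumed Whisper-style output string (timestamp tokens are kept because `skip_special_tokens=False`):

    import re

    # Assumed decoded output with timestamp tokens left in place
    transcription = "<|0.00|> Hello there.<|2.40|><|2.40|> How are you?<|4.80|>"
    print(re.findall(r"> ([^<>]+)", transcription))  # ['Hello there.', 'How are you?']

    # An all-silence window decodes to tokens only, so findall returns []
    print(re.findall(r"> ([^<>]+)", "<|0.00|><|2.00|>"))  # []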
config.toml CHANGED
@@ -30,6 +30,7 @@ NOTI_RESULTS = [
     "Video uploaded, you can perform calculations",
 ]
 REC_TEXT = "Recognized text"
+VIDEO_DURATION = "Video duration: {:.2f} seconds"
 
 [OtherMessages]
 CLEAR = "Clear"
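Note: the new key is a plain `str.format` template; the submit handler fills it with the extractor's duration via `config_data.InformationMessages_VIDEO_DURATION.format(vfe.dur)`. A tiny sketch with an illustrative number:

    VIDEO_DURATION = "Video duration: {:.2f} seconds"
    print(VIDEO_DURATION.format(12.3456))  # -> "Video duration: 12.35 seconds"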