jhj0517 committed on
Commit 1a63918 · unverified · 2 Parent(s): fb62be2 f197459

Merge pull request #366 from jhj0517/feature/enable-word-timestamps

app.py CHANGED
@@ -53,7 +53,7 @@ class App:
             dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
                                   value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
                                   else whisper_params["lang"], label=_("Language"))
-            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
+            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format"))
             with gr.Row():
                 cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
                                            interactive=True)
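
For reference, LRC is the timestamped-lyrics format used by music players. With the new WriteLRC writer added in modules/utils/subtitle_manager.py below (always_include_hours=False, decimal_marker="."), a plain cue is written as, for example:

    [00:01.250]first transcribed line[00:03.900]
    [00:04.100]second transcribed line[00:06.480]
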
modules/diarize/diarize_pipeline.py CHANGED
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch
 
+from modules.whisper.data_classes import *
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
 
@@ -44,7 +45,8 @@ class DiarizationPipeline:
 def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
-        seg = seg.dict()
+        if isinstance(seg, Segment):
+            seg = seg.model_dump()
         # assign speaker to segment (if any)
         diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
                                                                                             seg['start'])
@@ -64,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
         seg["speaker"] = speaker
 
         # assign speaker to words
-        if 'words' in seg:
+        if 'words' in seg and seg['words'] is not None:
            for word in seg['words']:
                if 'start' in word:
                    diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
@@ -89,7 +91,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
     return transcript_result
 
 
-class Segment:
+class DiarizationSegment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
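
A minimal sketch of why the new isinstance guard and model_dump() matter: segments may now arrive as pydantic models rather than plain dicts, and pydantic v2 renamed .dict() to .model_dump():

    from modules.whisper.data_classes import Segment

    seg = Segment(start=0.0, end=1.5, text="hello")
    if isinstance(seg, Segment):
        seg = seg.model_dump()   # pydantic v2 replacement for the removed .dict()
    assert seg["text"] == "hello" and seg["words"] is None

The words field defaults to None after dumping, which is why the word-assignment loop now also checks seg['words'] is not None.
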
modules/translation/deepl_api.py CHANGED
@@ -139,37 +139,27 @@ class DeepLAPI:
        )
 
        files_info = {}
-       for fileobj in fileobjs:
-           file_path = fileobj
-           file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-
-           if file_ext == ".srt":
-               parsed_dicts = parse_srt(file_path=file_path)
-
-           elif file_ext == ".vtt":
-               parsed_dicts = parse_vtt(file_path=file_path)
+       for file_path in fileobjs:
+           file_name, file_ext = os.path.splitext(os.path.basename(file_path))
+           writer = get_writer(file_ext, self.output_dir)
+           segments = writer.to_segments(file_path)
 
            batch_size = self.max_text_batch_size
-           for batch_start in range(0, len(parsed_dicts), batch_size):
-               batch_end = min(batch_start + batch_size, len(parsed_dicts))
-               sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
+           for batch_start in range(0, len(segments), batch_size):
+               progress(batch_start / len(segments), desc="Translating..")
+               sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
                translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                target_lang, is_pro)
                for i, translated_text in enumerate(translated_texts):
-                   parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-                   progress(batch_end / len(parsed_dicts), desc="Translating..")
-
-           if file_ext == ".srt":
-               subtitle = get_serialized_srt(parsed_dicts)
-           elif file_ext == ".vtt":
-               subtitle = get_serialized_vtt(parsed_dicts)
-
-           if add_timestamp:
-               timestamp = datetime.now().strftime("%m%d%H%M%S")
-               file_name += f"-{timestamp}"
-
-           output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
-           write_file(subtitle, output_path)
+                   segments[batch_start + i].text = translated_text["text"]
+
+           subtitle, output_path = generate_file(
+               output_dir=self.output_dir,
+               output_file_name=file_name,
+               output_format=file_ext,
+               result=segments,
+               add_timestamp=add_timestamp
+           )
 
            files_info[file_name] = {"subtitle": subtitle, "path": output_path}
modules/translation/translation_base.py CHANGED
@@ -95,32 +95,22 @@ class TranslationBase(ABC):
            files_info = {}
            for fileobj in fileobjs:
                file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-               if file_ext == ".srt":
-                   parsed_dicts = parse_srt(file_path=fileobj)
-                   total_progress = len(parsed_dicts)
-                   for index, dic in enumerate(parsed_dicts):
-                       progress(index / total_progress, desc="Translating..")
-                       translated_text = self.translate(dic["sentence"], max_length=max_length)
-                       dic["sentence"] = translated_text
-                   subtitle = get_serialized_srt(parsed_dicts)
-
-               elif file_ext == ".vtt":
-                   parsed_dicts = parse_vtt(file_path=fileobj)
-                   total_progress = len(parsed_dicts)
-                   for index, dic in enumerate(parsed_dicts):
-                       progress(index / total_progress, desc="Translating..")
-                       translated_text = self.translate(dic["sentence"], max_length=max_length)
-                       dic["sentence"] = translated_text
-                   subtitle = get_serialized_vtt(parsed_dicts)
-
-               if add_timestamp:
-                   timestamp = datetime.now().strftime("%m%d%H%M%S")
-                   file_name += f"-{timestamp}"
-
-               output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
-               write_file(subtitle, output_path)
-
-               files_info[file_name] = {"subtitle": subtitle, "path": output_path}
+               writer = get_writer(file_ext, self.output_dir)
+               segments = writer.to_segments(fileobj)
+               for i, segment in enumerate(segments):
+                   progress(i / len(segments), desc="Translating..")
+                   translated_text = self.translate(segment.text, max_length=max_length)
+                   segment.text = translated_text
+
+               subtitle, file_path = generate_file(
+                   output_dir=self.output_dir,
+                   output_file_name=file_name,
+                   output_format=file_ext,
+                   result=segments,
+                   add_timestamp=add_timestamp
+               )
+
+               files_info[file_name] = {"subtitle": subtitle, "path": file_path}
 
            total_result = ''
            for file_name, info in files_info.items():
@@ -133,7 +123,8 @@ class TranslationBase(ABC):
            return [gr_str, output_file_paths]
 
        except Exception as e:
-           print(f"Error: {str(e)}")
+           print(f"Error translating file: {e}")
+           raise
        finally:
            self.release_cuda_memory()
modules/utils/files_manager.py CHANGED
@@ -67,3 +67,9 @@ def is_video(file_path):
     video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp']
     extension = os.path.splitext(file_path)[1].lower()
     return extension in video_extensions
+
+
+def read_file(file_path):
+    with open(file_path, "r", encoding="utf-8") as f:
+        subtitle_content = f.read()
+    return subtitle_content
modules/utils/subtitle_manager.py CHANGED
@@ -1,128 +1,424 @@
-import re
-
-from modules.whisper.data_classes import Segment
-
-
-def timeformat_srt(time):
-    hours = time // 3600
-    minutes = (time - hours * 3600) // 60
-    seconds = time - hours * 3600 - minutes * 60
-    milliseconds = (time - int(time)) * 1000
-    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
-
-
-def timeformat_vtt(time):
-    hours = time // 3600
-    minutes = (time - hours * 3600) // 60
-    seconds = time - hours * 3600 - minutes * 60
-    milliseconds = (time - int(time)) * 1000
-    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
-
-
-def write_file(subtitle, output_file):
-    with open(output_file, 'w', encoding='utf-8') as f:
-        f.write(subtitle)
-
-
-def get_srt(segments):
-    if segments and isinstance(segments[0], Segment):
-        segments = [seg.dict() for seg in segments]
-
-    output = ""
-    for i, segment in enumerate(segments):
-        output += f"{i + 1}\n"
-        output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
-        if segment['text'].startswith(' '):
-            segment['text'] = segment['text'][1:]
-        output += f"{segment['text']}\n\n"
-    return output
-
-
-def get_vtt(segments):
-    if segments and isinstance(segments[0], Segment):
-        segments = [seg.dict() for seg in segments]
-
-    output = "WEBVTT\n\n"
-    for i, segment in enumerate(segments):
-        output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
-        if segment['text'].startswith(' '):
-            segment['text'] = segment['text'][1:]
-        output += f"{segment['text']}\n\n"
-    return output
-
-
-def get_txt(segments):
-    if segments and isinstance(segments[0], Segment):
-        segments = [seg.dict() for seg in segments]
-
-    output = ""
-    for i, segment in enumerate(segments):
-        if segment['text'].startswith(' '):
-            segment['text'] = segment['text'][1:]
-        output += f"{segment['text']}\n"
-    return output
-
-
-def parse_srt(file_path):
-    """Reads SRT file and returns as dict"""
-    with open(file_path, 'r', encoding='utf-8') as file:
-        srt_data = file.read()
-
-    data = []
-    blocks = srt_data.split('\n\n')
-
-    for block in blocks:
-        if block.strip() != '':
-            lines = block.strip().split('\n')
-            index = lines[0]
-            timestamp = lines[1]
-            sentence = ' '.join(lines[2:])
-
-            data.append({
-                "index": index,
-                "timestamp": timestamp,
-                "sentence": sentence
-            })
-    return data
-
-
-def parse_vtt(file_path):
-    """Reads WEBVTT file and returns as dict"""
-    with open(file_path, 'r', encoding='utf-8') as file:
-        webvtt_data = file.read()
-
-    data = []
-    blocks = webvtt_data.split('\n\n')
-
-    for block in blocks:
-        if block.strip() != '' and not block.strip().startswith("WEBVTT"):
-            lines = block.strip().split('\n')
-            timestamp = lines[0]
-            sentence = ' '.join(lines[1:])
-
-            data.append({
-                "timestamp": timestamp,
-                "sentence": sentence
-            })
-
-    return data
-
-
-def get_serialized_srt(dicts):
-    output = ""
-    for dic in dicts:
-        output += f'{dic["index"]}\n'
-        output += f'{dic["timestamp"]}\n'
-        output += f'{dic["sentence"]}\n\n'
-    return output
-
-
-def get_serialized_vtt(dicts):
-    output = "WEBVTT\n\n"
-    for dic in dicts:
-        output += f'{dic["timestamp"]}\n'
-        output += f'{dic["sentence"]}\n\n'
-    return output
+# Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py
+
+import json
+import os
+import re
+import sys
+import zlib
+from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
+from datetime import datetime
+
+from modules.whisper.data_classes import Segment, Word
+from .files_manager import read_file
+
+
+def format_timestamp(
+    seconds: float, always_include_hours: bool = True, decimal_marker: str = ","
+) -> str:
+    assert seconds >= 0, "non-negative timestamp expected"
+    milliseconds = round(seconds * 1000.0)
+
+    hours = milliseconds // 3_600_000
+    milliseconds -= hours * 3_600_000
+
+    minutes = milliseconds // 60_000
+    milliseconds -= minutes * 60_000
+
+    seconds = milliseconds // 1_000
+    milliseconds -= seconds * 1_000
+
+    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )
+
+
+def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float:
+    times = time_str.split(":")
+
+    if len(times) == 3:
+        hours, minutes, rest = times
+        hours = int(hours)
+    else:
+        hours = 0
+        minutes, rest = times
+
+    seconds, fractional = rest.split(decimal_marker)
+
+    minutes = int(minutes)
+    seconds = int(seconds)
+    fractional_seconds = float("0." + fractional)
+
+    return hours * 3600 + minutes * 60 + seconds + fractional_seconds
+
+
+def get_start(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["start"] for s in segments for w in s["words"]),
+        segments[0]["start"] if segments else None,
+    )
+
+
+def get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
+
+
+class ResultWriter:
+    extension: str
+
+    def __init__(self, output_dir: str):
+        self.output_dir = output_dir
+
+    def __call__(
+        self, result: Union[dict, List[Segment]], output_file_name: str,
+        options: Optional[dict] = None, **kwargs
+    ):
+        if isinstance(result, List) and result and isinstance(result[0], Segment):
+            result = {"segments": [seg.model_dump() for seg in result]}
+
+        output_path = os.path.join(
+            self.output_dir, output_file_name + "." + self.extension
+        )
+
+        with open(output_path, "w", encoding="utf-8") as f:
+            self.write_result(result, file=f, options=options, **kwargs)
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        raise NotImplementedError
+
+
+class WriteTXT(ResultWriter):
+    extension: str = "txt"
+
+    def write_result(
+        self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        for segment in result["segments"]:
+            print(segment["text"].strip(), file=file, flush=True)
+
+
+class SubtitlesWriter(ResultWriter):
+    always_include_hours: bool
+    decimal_marker: str
+
+    def iterate_result(
+        self,
+        result: dict,
+        options: Optional[dict] = None,
+        *,
+        max_line_width: Optional[int] = None,
+        max_line_count: Optional[int] = None,
+        highlight_words: bool = False,
+        align_lrc_words: bool = False,
+        max_words_per_line: Optional[int] = None,
+    ):
+        options = options or {}
+        max_line_width = max_line_width or options.get("max_line_width")
+        max_line_count = max_line_count or options.get("max_line_count")
+        highlight_words = highlight_words or options.get("highlight_words", False)
+        align_lrc_words = align_lrc_words or options.get("align_lrc_words", False)
+        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
+        preserve_segments = max_line_count is None or max_line_width is None
+        max_line_width = max_line_width or 1000
+        max_words_per_line = max_words_per_line or 1000
+
+        def iterate_subtitles():
+            line_len = 0
+            line_count = 1
+            # the next subtitle to yield (a list of word timings with whitespace)
+            subtitle: List[dict] = []
+            last: float = get_start(result["segments"]) or 0.0
+            for segment in result["segments"]:
+                chunk_index = 0
+                words_count = max_words_per_line
+                while chunk_index < len(segment["words"]):
+                    remaining_words = len(segment["words"]) - chunk_index
+                    if max_words_per_line > len(segment["words"]) - chunk_index:
+                        words_count = remaining_words
+                    for i, original_timing in enumerate(
+                        segment["words"][chunk_index : chunk_index + words_count]
+                    ):
+                        timing = original_timing.copy()
+                        long_pause = (
+                            not preserve_segments and timing["start"] - last > 3.0
+                        )
+                        has_room = line_len + len(timing["word"]) <= max_line_width
+                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+                        if (
+                            line_len > 0
+                            and has_room
+                            and not long_pause
+                            and not seg_break
+                        ):
+                            # line continuation
+                            line_len += len(timing["word"])
+                        else:
+                            # new line
+                            timing["word"] = timing["word"].strip()
+                            if (
+                                len(subtitle) > 0
+                                and max_line_count is not None
+                                and (long_pause or line_count >= max_line_count)
+                                or seg_break
+                            ):
+                                # subtitle break
+                                yield subtitle
+                                subtitle = []
+                                line_count = 1
+                            elif line_len > 0:
+                                # line break
+                                line_count += 1
+                                timing["word"] = "\n" + timing["word"]
+                            line_len = len(timing["word"].strip())
+                        subtitle.append(timing)
+                        last = timing["start"]
+                    chunk_index += max_words_per_line
+            if len(subtitle) > 0:
+                yield subtitle
+
+        if len(result["segments"]) > 0 and "words" in result["segments"][0] and result["segments"][0]["words"]:
+            for subtitle in iterate_subtitles():
+                subtitle_start = self.format_timestamp(subtitle[0]["start"])
+                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+                subtitle_text = "".join([word["word"] for word in subtitle])
+                if highlight_words:
+                    last = subtitle_start
+                    all_words = [timing["word"] for timing in subtitle]
+                    for i, this_word in enumerate(subtitle):
+                        start = self.format_timestamp(this_word["start"])
+                        end = self.format_timestamp(this_word["end"])
+                        if last != start:
+                            yield last, start, subtitle_text
+
+                        yield start, end, "".join(
+                            [
+                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                if j == i
+                                else word
+                                for j, word in enumerate(all_words)
+                            ]
+                        )
+                        last = end
+
+                if align_lrc_words:
+                    lrc_aligned_words = [f"[{self.format_timestamp(sub['start'])}]{sub['word']}" for sub in subtitle]
+                    l_start, l_end = self.format_timestamp(subtitle[-1]['start']), self.format_timestamp(subtitle[-1]['end'])
+                    lrc_aligned_words[-1] = f"[{l_start}]{subtitle[-1]['word']}[{l_end}]"
+                    lrc_aligned_words = ' '.join(lrc_aligned_words)
+                    yield None, None, lrc_aligned_words
+
+                else:
+                    yield subtitle_start, subtitle_end, subtitle_text
+        else:
+            for segment in result["segments"]:
+                segment_start = self.format_timestamp(segment["start"])
+                segment_end = self.format_timestamp(segment["end"])
+                segment_text = segment["text"].strip().replace("-->", "->")
+                yield segment_start, segment_end, segment_text
+
+    def format_timestamp(self, seconds: float):
+        return format_timestamp(
+            seconds=seconds,
+            always_include_hours=self.always_include_hours,
+            decimal_marker=self.decimal_marker,
+        )
+
+
+class WriteVTT(SubtitlesWriter):
+    extension: str = "vtt"
+    always_include_hours: bool = False
+    decimal_marker: str = "."
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        print("WEBVTT\n", file=file)
+        for start, end, text in self.iterate_result(result, options, **kwargs):
+            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
+
+    def to_segments(self, file_path: str) -> List[Segment]:
+        segments = []
+
+        blocks = read_file(file_path).split('\n\n')
+
+        for block in blocks:
+            if block.strip() != '' and not block.strip().startswith("WEBVTT"):
+                lines = block.strip().split('\n')
+                time_line = lines[0].split(" --> ")
+                start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
+                sentence = ' '.join(lines[1:])
+
+                segments.append(Segment(
+                    start=start,
+                    end=end,
+                    text=sentence
+                ))
+
+        return segments
+
+
+class WriteSRT(SubtitlesWriter):
+    extension: str = "srt"
+    always_include_hours: bool = True
+    decimal_marker: str = ","
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options, **kwargs), start=1
+        ):
+            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+    def to_segments(self, file_path: str) -> List[Segment]:
+        segments = []
+
+        blocks = read_file(file_path).split('\n\n')
+
+        for block in blocks:
+            if block.strip() != '':
+                lines = block.strip().split('\n')
+                index = lines[0]
+                time_line = lines[1].split(" --> ")
+                start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
+                sentence = ' '.join(lines[2:])
+
+                segments.append(Segment(
+                    start=start,
+                    end=end,
+                    text=sentence
+                ))
+
+        return segments
+
+
+class WriteLRC(SubtitlesWriter):
+    extension: str = "lrc"
+    always_include_hours: bool = False
+    decimal_marker: str = "."
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options, **kwargs), start=1
+        ):
+            if "align_lrc_words" in kwargs and kwargs["align_lrc_words"]:
+                print(f"{text}\n", file=file, flush=True)
+            else:
+                print(f"[{start}]{text}[{end}]\n", file=file, flush=True)
+
+    def to_segments(self, file_path: str) -> List[Segment]:
+        segments = []
+
+        blocks = read_file(file_path).split('\n')
+
+        for block in blocks:
+            if block.strip() != '':
+                lines = block.strip()
+                pattern = r'(\[.*?\])'
+                parts = re.split(pattern, lines)
+                parts = [part.strip() for part in parts if part]
+
+                for i, part in enumerate(parts):
+                    if i % 2 == 1:
+                        # odd positions hold the text between two bracketed timestamps
+                        start_str, text, end_str = parts[i - 1], parts[i], parts[i + 1]
+                        start_str, end_str = start_str.replace("[", "").replace("]", ""), end_str.replace("[", "").replace("]", "")
+                        start, end = time_str_to_seconds(start_str, self.decimal_marker), time_str_to_seconds(end_str, self.decimal_marker)
+
+                        segments.append(Segment(
+                            start=start,
+                            end=end,
+                            text=text,
+                        ))
+
+        return segments
+
+
+class WriteTSV(ResultWriter):
+    """
+    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+    Using integer milliseconds as start and end times means there's no chance of interference from
+    an environment setting a language encoding that causes the decimal in a floating point number
+    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+    """
+
+    extension: str = "tsv"
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        print("start", "end", "text", sep="\t", file=file)
+        for segment in result["segments"]:
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+    extension: str = "json"
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        json.dump(result, file)
+
+
+def get_writer(
+    output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
+    output_format = output_format.strip().lower().replace(".", "")
+
+    writers = {
+        "txt": WriteTXT,
+        "vtt": WriteVTT,
+        "srt": WriteSRT,
+        "tsv": WriteTSV,
+        "json": WriteJSON,
+        "lrc": WriteLRC
+    }
+
+    if output_format == "all":
+        all_writers = [writer(output_dir) for writer in writers.values()]
+
+        def write_all(
+            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+        ):
+            for writer in all_writers:
+                writer(result, file, options, **kwargs)
+
+        return write_all
+
+    return writers[output_format](output_dir)
+
+
+def generate_file(
+    output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str,
+    add_timestamp: bool = True, **kwargs
+) -> Tuple[str, str]:
+    output_format = output_format.strip().lower().replace(".", "")
+
+    if add_timestamp:
+        timestamp = datetime.now().strftime("%m%d%H%M%S")
+        output_file_name += f"-{timestamp}"
+
+    file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
+    file_writer = get_writer(output_format=output_format, output_dir=output_dir)
+
+    if isinstance(file_writer, WriteLRC) and kwargs.get("highlight_words", False):
+        kwargs["highlight_words"], kwargs["align_lrc_words"] = False, True
+
+    file_writer(result=result, output_file_name=output_file_name, **kwargs)
+    content = read_file(file_path)
+    return content, file_path
 
 
 def safe_filename(name):
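
A short usage sketch of the new writer stack, assuming word timestamps are present. highlight_words wraps the active word in <u> tags for SRT/VTT, and generate_file flips it to align_lrc_words for LRC; the to_segments parsers and WriteLRC appear to be webui additions on top of the whisper port:

    from modules.utils.subtitle_manager import generate_file
    from modules.whisper.data_classes import Segment, Word

    segments = [Segment(
        start=0.0, end=1.2, text=" Hello world",
        words=[Word(start=0.0, end=0.5, word=" Hello"),
               Word(start=0.5, end=1.2, word=" world")],
    )]
    content, path = generate_file(
        output_format="srt", output_dir="outputs", result=segments,
        output_file_name="demo", add_timestamp=False, highlight_words=True,
    )
    # The first cue in content looks like:
    # 1
    # 00:00:00,000 --> 00:00:00,500
    # <u>Hello</u> world
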
modules/vad/silero_vad.py CHANGED
@@ -259,7 +259,7 @@ class SileroVAD:
 
        for segment in segments:
            segment.start = ts_map.get_original_time(segment.start)
-           segment.start = ts_map.get_original_time(segment.start)
+           segment.end = ts_map.get_original_time(segment.end)
 
        return segments
 
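
The replaced line was a copy-paste slip: segment.start was mapped back twice while segment.end never left the VAD-compressed timeline, so segments ended too early whenever VAD removed silence. A toy illustration with made-up numbers (get_original_time stands in for the faster_whisper timestamp map used above):

    # Speech spans 0-2s and 5-7s in the original audio; VAD removes the 3s gap,
    # so compressed times >= 2.0 map back with a 3.0s offset.
    def get_original_time(t: float) -> float:
        return t if t < 2.0 else t + 3.0

    start, end = 1.5, 2.5                # segment on the compressed timeline
    start = get_original_time(start)     # 1.5
    end = get_original_time(end)         # 5.5 -- stayed at 2.5 before the fix
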
modules/whisper/base_transcription_pipeline.py CHANGED
@@ -1,6 +1,4 @@
 import os
-import torch
-import ast
 import whisper
 import ctranslate2
 import gradio as gr
@@ -10,15 +8,14 @@ from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 from datetime import datetime
 from faster_whisper.vad import VadOptions
-from dataclasses import astuple
 
 from modules.uvr.music_separator import MusicSeparator
 from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                  UVR_MODELS_DIR)
 from modules.utils.constants import *
-from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
+from modules.utils.subtitle_manager import *
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
-from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
+from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
 from modules.whisper.data_classes import *
 from modules.diarize.diarizer import Diarizer
 from modules.vad.silero_vad import SileroVAD
@@ -76,7 +73,7 @@ class BaseTranscriptionPipeline(ABC):
            progress: gr.Progress = gr.Progress(),
            add_timestamp: bool = True,
            *pipeline_params,
-           ) -> Tuple[List[dict], float]:
+           ) -> Tuple[List[Segment], float]:
        """
        Run transcription with conditional pre-processing and post-processing.
        The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
@@ -92,12 +89,14 @@ class BaseTranscriptionPipeline(ABC):
        add_timestamp: bool
            Whether to add a timestamp at the end of the filename.
        *pipeline_params: tuple
-           Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
+           Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class.
+           This must be provided as a List with * wildcard because of the integration with gradio.
+           See more info at : https://github.com/gradio-app/gradio/issues/2471
 
        Returns
        ----------
-       segments_result: List[dict]
-           list of dicts that includes start, end timestamps and transcribed text
+       segments_result: List[Segment]
+           list of Segment that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for running
        """
@@ -179,7 +178,7 @@ class BaseTranscriptionPipeline(ABC):
                             file_format: str = "SRT",
                             add_timestamp: bool = True,
                             progress=gr.Progress(),
-                            *params,
+                            *pipeline_params,
                             ) -> list:
        """
        Write subtitle file from Files
@@ -197,7 +196,7 @@ class BaseTranscriptionPipeline(ABC):
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
-       *params: tuple
+       *pipeline_params: tuple
            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
 
        Returns
@@ -208,6 +207,11 @@ class BaseTranscriptionPipeline(ABC):
            Output file path to return to gr.Files()
        """
        try:
+           params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+           writer_options = {
+               "highlight_words": True if params.whisper.word_timestamps else False
+           }
+
            if input_folder_path:
                files = get_media_files(input_folder_path)
                if isinstance(files, str):
@@ -221,18 +225,19 @@ class BaseTranscriptionPipeline(ABC):
                    file,
                    progress,
                    add_timestamp,
-                   *params,
+                   *pipeline_params,
                )
 
                file_name, file_ext = os.path.splitext(os.path.basename(file))
-               subtitle, file_path = self.generate_and_write_file(
-                   file_name=file_name,
-                   transcribed_segments=transcribed_segments,
+               subtitle, file_path = generate_file(
+                   output_dir=self.output_dir,
+                   output_file_name=file_name,
+                   output_format=file_format,
+                   result=transcribed_segments,
                    add_timestamp=add_timestamp,
-                   file_format=file_format,
-                   output_dir=self.output_dir
+                   **writer_options
                )
-               files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
+               files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
 
            total_result = ''
            total_time = 0
@@ -249,6 +254,7 @@ class BaseTranscriptionPipeline(ABC):
 
        except Exception as e:
            print(f"Error transcribing file: {e}")
+           raise
        finally:
            self.release_cuda_memory()
 
@@ -257,7 +263,7 @@ class BaseTranscriptionPipeline(ABC):
                               file_format: str = "SRT",
                               add_timestamp: bool = True,
                               progress=gr.Progress(),
-                              *whisper_params,
+                              *pipeline_params,
                               ) -> list:
        """
        Write subtitle file from microphone
@@ -272,7 +278,7 @@ class BaseTranscriptionPipeline(ABC):
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
-       *whisper_params: tuple
+       *pipeline_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
        Returns
@@ -283,27 +289,35 @@ class BaseTranscriptionPipeline(ABC):
            Output file path to return to gr.Files()
        """
        try:
+           params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+           writer_options = {
+               "highlight_words": True if params.whisper.word_timestamps else False
+           }
+
            progress(0, desc="Loading Audio..")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                add_timestamp,
-               *whisper_params,
+               *pipeline_params,
            )
            progress(1, desc="Completed!")
 
-           subtitle, result_file_path = self.generate_and_write_file(
-               file_name="Mic",
-               transcribed_segments=transcribed_segments,
+           file_name = "Mic"
+           subtitle, file_path = generate_file(
+               output_dir=self.output_dir,
+               output_file_name=file_name,
+               output_format=file_format,
+               result=transcribed_segments,
                add_timestamp=add_timestamp,
-               file_format=file_format,
-               output_dir=self.output_dir
+               **writer_options
            )
 
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-           return [result_str, result_file_path]
+           return [result_str, file_path]
        except Exception as e:
-           print(f"Error transcribing file: {e}")
+           print(f"Error transcribing mic: {e}")
+           raise
        finally:
            self.release_cuda_memory()
 
@@ -312,7 +326,7 @@ class BaseTranscriptionPipeline(ABC):
                                file_format: str = "SRT",
                                add_timestamp: bool = True,
                                progress=gr.Progress(),
-                               *whisper_params,
+                               *pipeline_params,
                                ) -> list:
        """
        Write subtitle file from Youtube
@@ -327,7 +341,7 @@ class BaseTranscriptionPipeline(ABC):
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
-       *whisper_params: tuple
+       *pipeline_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
        Returns
@@ -338,6 +352,11 @@ class BaseTranscriptionPipeline(ABC):
            Output file path to return to gr.Files()
        """
        try:
+           params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+           writer_options = {
+               "highlight_words": True if params.whisper.word_timestamps else False
+           }
+
            progress(0, desc="Loading Audio from Youtube..")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)
@@ -346,28 +365,31 @@ class BaseTranscriptionPipeline(ABC):
                audio,
                progress,
                add_timestamp,
-               *whisper_params,
+               *pipeline_params,
            )
 
            progress(1, desc="Completed!")
 
            file_name = safe_filename(yt.title)
-           subtitle, result_file_path = self.generate_and_write_file(
-               file_name=file_name,
-               transcribed_segments=transcribed_segments,
+           subtitle, file_path = generate_file(
+               output_dir=self.output_dir,
+               output_file_name=file_name,
+               output_format=file_format,
+               result=transcribed_segments,
                add_timestamp=add_timestamp,
-               file_format=file_format,
-               output_dir=self.output_dir
+               **writer_options
            )
+
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
            if os.path.exists(audio):
                os.remove(audio)
 
-           return [result_str, result_file_path]
+           return [result_str, file_path]
 
        except Exception as e:
-           print(f"Error transcribing file: {e}")
+           print(f"Error transcribing youtube: {e}")
+           raise
        finally:
            self.release_cuda_memory()
 
@@ -385,58 +407,6 @@ class BaseTranscriptionPipeline(ABC):
        else:
            return list(ctranslate2.get_supported_compute_types("cpu"))
 
-   @staticmethod
-   def generate_and_write_file(file_name: str,
-                               transcribed_segments: list,
-                               add_timestamp: bool,
-                               file_format: str,
-                               output_dir: str
-                               ) -> str:
-       """
-       Writes subtitle file
-
-       Parameters
-       ----------
-       file_name: str
-           Output file name
-       transcribed_segments: list
-           Text segments transcribed from audio
-       add_timestamp: bool
-           Determines whether to add a timestamp to the end of the filename.
-       file_format: str
-           File format to write. Supported formats: [SRT, WebVTT, txt]
-       output_dir: str
-           Directory path of the output
-
-       Returns
-       ----------
-       content: str
-           Result of the transcription
-       output_path: str
-           output file path
-       """
-       if add_timestamp:
-           timestamp = datetime.now().strftime("%m%d%H%M%S")
-           output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
-       else:
-           output_path = os.path.join(output_dir, f"{file_name}")
-
-       file_format = file_format.strip().lower()
-       if file_format == "srt":
-           content = get_srt(transcribed_segments)
-           output_path += '.srt'
-
-       elif file_format == "webvtt":
-           content = get_vtt(transcribed_segments)
-           output_path += '.vtt'
-
-       elif file_format == "txt":
-           content = get_txt(transcribed_segments)
-           output_path += '.txt'
-
-       write_file(content, output_path)
-       return content, output_path
-
    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
modules/whisper/data_classes.py CHANGED
@@ -1,10 +1,12 @@
+import faster_whisper.transcribe
 import gradio as gr
 import torch
-from typing import Optional, Dict, List, Union
+from typing import Optional, Dict, List, Union, NamedTuple
 from pydantic import BaseModel, Field, field_validator, ConfigDict
 from gradio_i18n import Translate, gettext as _
 from enum import Enum
 from copy import deepcopy
+
 import yaml
 
 from modules.utils.constants import *
@@ -17,12 +19,53 @@ class WhisperImpl(Enum):
 
 
 class Segment(BaseModel):
-    text: Optional[str] = Field(default=None,
-                                description="Transcription text of the segment")
-    start: Optional[float] = Field(default=None,
-                                   description="Start time of the segment")
-    end: Optional[float] = Field(default=None,
-                                 description="End time of the segment")
+    id: Optional[int] = Field(default=None, description="Incremental id for the segment")
+    seek: Optional[int] = Field(default=None, description="Seek of the segment from chunked audio")
+    text: Optional[str] = Field(default=None, description="Transcription text of the segment")
+    start: Optional[float] = Field(default=None, description="Start time of the segment")
+    end: Optional[float] = Field(default=None, description="End time of the segment")
+    tokens: Optional[List[int]] = Field(default=None, description="List of token IDs")
+    temperature: Optional[float] = Field(default=None, description="Temperature used during the decoding process")
+    avg_logprob: Optional[float] = Field(default=None, description="Average log probability of the tokens")
+    compression_ratio: Optional[float] = Field(default=None, description="Compression ratio of the segment")
+    no_speech_prob: Optional[float] = Field(default=None, description="Probability that it's not speech")
+    words: Optional[List['Word']] = Field(default=None, description="List of words contained in the segment")
+
+    @classmethod
+    def from_faster_whisper(cls,
+                            seg: faster_whisper.transcribe.Segment):
+        if seg.words is not None:
+            words = [
+                Word(
+                    start=w.start,
+                    end=w.end,
+                    word=w.word,
+                    probability=w.probability
+                ) for w in seg.words
+            ]
+        else:
+            words = None
+
+        return cls(
+            id=seg.id,
+            seek=seg.seek,
+            text=seg.text,
+            start=seg.start,
+            end=seg.end,
+            tokens=seg.tokens,
+            temperature=seg.temperature,
+            avg_logprob=seg.avg_logprob,
+            compression_ratio=seg.compression_ratio,
+            no_speech_prob=seg.no_speech_prob,
+            words=words
+        )
+
+
+class Word(BaseModel):
+    start: Optional[float] = Field(default=None, description="Start time of the word")
+    end: Optional[float] = Field(default=None, description="End time of the word")
+    word: Optional[str] = Field(default=None, description="Word text")
+    probability: Optional[float] = Field(default=None, description="Probability of the word")
 
 
 class BaseParams(BaseModel):
@@ -250,9 +293,9 @@ class WhisperParams(BaseParams):
        default=True,
        description="Suppress blank outputs at start of sampling"
    )
-   suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
+   suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress")
    max_initial_timestamp: float = Field(
-       default=0.0,
+       default=1.0,
        ge=0.0,
        description="Maximum initial timestamp"
    )
modules/whisper/faster_whisper_inference.py CHANGED
@@ -102,11 +102,7 @@ class FasterWhisperInference(BaseTranscriptionPipeline):
        segments_result = []
        for segment in segments:
            progress(segment.start / info.duration, desc="Transcribing..")
-           segments_result.append(Segment(
-               start=segment.start,
-               end=segment.end,
-               text=segment.text
-           ))
+           segments_result.append(Segment.from_faster_whisper(segment))
 
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time