jhj0517 committed on
Commit
a1b32c1
·
1 Parent(s): a377305

Use `generate_file()`

Browse files
modules/utils/subtitle_manager.py CHANGED
@@ -5,10 +5,11 @@ import os
5
  import re
6
  import sys
7
  import zlib
8
- from typing import Callable, List, Optional, TextIO, Union, Dict
9
  from datetime import datetime
10
 
11
  from modules.whisper.data_classes import Segment
 
12
 
13
 
14
  def format_timestamp(
@@ -61,7 +62,7 @@ class ResultWriter:
61
 
62
  if add_timestamp:
63
  timestamp = datetime.now().strftime("%m%d%H%M%S")
64
- output_file_name += timestamp
65
 
66
  output_path = os.path.join(
67
  self.output_dir, output_file_name + "." + self.extension
@@ -264,6 +265,8 @@ class WriteJSON(ResultWriter):
264
  def get_writer(
265
  output_format: str, output_dir: str
266
  ) -> Callable[[dict, TextIO, dict], None]:
 
 
267
  writers = {
268
  "txt": WriteTXT,
269
  "vtt": WriteVTT,
@@ -286,6 +289,16 @@ def get_writer(
286
  return writers[output_format](output_dir)
287
 
288
 
 
 
 
 
 
 
 
 
 
 
289
  def parse_srt(file_path):
290
  """Reads SRT file and returns as dict"""
291
  with open(file_path, 'r', encoding='utf-8') as file:
 
5
  import re
6
  import sys
7
  import zlib
8
+ from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
9
  from datetime import datetime
10
 
11
  from modules.whisper.data_classes import Segment
12
+ from .files_manager import read_file
13
 
14
 
15
  def format_timestamp(
 
62
 
63
  if add_timestamp:
64
  timestamp = datetime.now().strftime("%m%d%H%M%S")
65
+ output_file_name += f"-{timestamp}"
66
 
67
  output_path = os.path.join(
68
  self.output_dir, output_file_name + "." + self.extension
 
265
  def get_writer(
266
  output_format: str, output_dir: str
267
  ) -> Callable[[dict, TextIO, dict], None]:
268
+ output_format = output_format.strip().lower()
269
+
270
  writers = {
271
  "txt": WriteTXT,
272
  "vtt": WriteVTT,
 
289
  return writers[output_format](output_dir)
290
 
291
 
292
def generate_file(
    output_format: str,
    output_dir: str,
    result: Union[dict, List[Segment]],
    output_file_name: str,
    add_timestamp: bool = True,
) -> Tuple[str, str]:
    """
    Writes the transcription result to a subtitle file and returns its content.

    Parameters
    ----------
    output_format: str
        Format key passed to `get_writer()` (e.g. "txt", "vtt"). Surrounding
        whitespace and letter case are ignored; "webvtt" is accepted as an
        alias for "vtt".
    output_dir: str
        Directory path of the output
    result: Union[dict, List[Segment]]
        Transcription result handed to the writer unchanged.
    output_file_name: str
        Output file name without extension.
    add_timestamp: bool
        Determines whether to add a "-MMDDHHMMSS" timestamp to the end of the filename.

    Returns
    ----------
    content: str
        Content of the written file, as returned by `read_file()`
    file_path: str
        Path of the written file
    """
    # Normalize up front so the path built here uses the same format key
    # that get_writer() resolves internally.
    output_format = output_format.strip().lower()
    # The previous pipeline helper accepted "webvtt" and wrote a ".vtt" file;
    # keep that spelling working against the "vtt" writer key.
    if output_format == "webvtt":
        output_format = "vtt"

    # Resolve the timestamp here, exactly once, so the path we read back is
    # the path the writer wrote. Letting the writer append its own timestamp
    # would leave `file_path` pointing at a file that was never created.
    if add_timestamp:
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"{output_file_name}-{timestamp}"

    # NOTE(review): assumes each writer's file extension equals its format
    # key (as with "txt" / "vtt") - confirm for any writer where they differ.
    file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
    file_writer = get_writer(output_format=output_format, output_dir=output_dir)
    # The timestamp (if any) is already part of the name; do not add another.
    file_writer(result=result, output_file_name=output_file_name, add_timestamp=False)
    content = read_file(file_path)
    return content, file_path
300
+
301
+
302
  def parse_srt(file_path):
303
  """Reads SRT file and returns as dict"""
304
  with open(file_path, 'r', encoding='utf-8') as file:
modules/whisper/base_transcription_pipeline.py CHANGED
@@ -13,9 +13,9 @@ from modules.uvr.music_separator import MusicSeparator
13
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
14
  UVR_MODELS_DIR)
15
  from modules.utils.constants import *
16
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
17
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
18
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
19
  from modules.whisper.data_classes import *
20
  from modules.diarize.diarizer import Diarizer
21
  from modules.vad.silero_vad import SileroVAD
@@ -224,14 +224,14 @@ class BaseTranscriptionPipeline(ABC):
224
  )
225
 
226
  file_name, file_ext = os.path.splitext(os.path.basename(file))
227
- subtitle, file_path = self.generate_and_write_file(
228
- file_name=file_name,
229
- transcribed_segments=transcribed_segments,
230
- add_timestamp=add_timestamp,
231
- file_format=file_format,
232
- output_dir=self.output_dir
233
  )
234
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
235
 
236
  total_result = ''
237
  total_time = 0
@@ -291,16 +291,17 @@ class BaseTranscriptionPipeline(ABC):
291
  )
292
  progress(1, desc="Completed!")
293
 
294
- subtitle, result_file_path = self.generate_and_write_file(
295
- file_name="Mic",
296
- transcribed_segments=transcribed_segments,
297
- add_timestamp=add_timestamp,
298
- file_format=file_format,
299
- output_dir=self.output_dir
 
300
  )
301
 
302
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
303
- return [result_str, result_file_path]
304
  except Exception as e:
305
  print(f"Error transcribing file: {e}")
306
  finally:
@@ -351,19 +352,20 @@ class BaseTranscriptionPipeline(ABC):
351
  progress(1, desc="Completed!")
352
 
353
  file_name = safe_filename(yt.title)
354
- subtitle, result_file_path = self.generate_and_write_file(
355
- file_name=file_name,
356
- transcribed_segments=transcribed_segments,
357
- add_timestamp=add_timestamp,
358
- file_format=file_format,
359
- output_dir=self.output_dir
360
  )
 
361
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
362
 
363
  if os.path.exists(audio):
364
  os.remove(audio)
365
 
366
- return [result_str, result_file_path]
367
 
368
  except Exception as e:
369
  print(f"Error transcribing file: {e}")
@@ -384,58 +386,6 @@ class BaseTranscriptionPipeline(ABC):
384
  else:
385
  return list(ctranslate2.get_supported_compute_types("cpu"))
386
 
387
- @staticmethod
388
- def generate_and_write_file(file_name: str,
389
- transcribed_segments: list,
390
- add_timestamp: bool,
391
- file_format: str,
392
- output_dir: str
393
- ) -> str:
394
- """
395
- Writes subtitle file
396
-
397
- Parameters
398
- ----------
399
- file_name: str
400
- Output file name
401
- transcribed_segments: list
402
- Text segments transcribed from audio
403
- add_timestamp: bool
404
- Determines whether to add a timestamp to the end of the filename.
405
- file_format: str
406
- File format to write. Supported formats: [SRT, WebVTT, txt]
407
- output_dir: str
408
- Directory path of the output
409
-
410
- Returns
411
- ----------
412
- content: str
413
- Result of the transcription
414
- output_path: str
415
- output file path
416
- """
417
- if add_timestamp:
418
- timestamp = datetime.now().strftime("%m%d%H%M%S")
419
- output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
420
- else:
421
- output_path = os.path.join(output_dir, f"{file_name}")
422
-
423
- file_format = file_format.strip().lower()
424
- if file_format == "srt":
425
- content = get_srt(transcribed_segments)
426
- output_path += '.srt'
427
-
428
- elif file_format == "webvtt":
429
- content = get_vtt(transcribed_segments)
430
- output_path += '.vtt'
431
-
432
- elif file_format == "txt":
433
- content = get_txt(transcribed_segments)
434
- output_path += '.txt'
435
-
436
- write_file(content, output_path)
437
- return content, output_path
438
-
439
  @staticmethod
440
  def format_time(elapsed_time: float) -> str:
441
  """
 
13
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
14
  UVR_MODELS_DIR)
15
  from modules.utils.constants import *
16
+ from modules.utils.subtitle_manager import *
17
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
18
+ from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
19
  from modules.whisper.data_classes import *
20
  from modules.diarize.diarizer import Diarizer
21
  from modules.vad.silero_vad import SileroVAD
 
224
  )
225
 
226
  file_name, file_ext = os.path.splitext(os.path.basename(file))
227
+ subtitle, file_path = generate_file(
228
+ output_dir=self.output_dir,
229
+ output_file_name=file_name,
230
+ output_format=file_format,
231
+ result=transcribed_segments,
232
+ add_timestamp=add_timestamp
233
  )
234
+ files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
235
 
236
  total_result = ''
237
  total_time = 0
 
291
  )
292
  progress(1, desc="Completed!")
293
 
294
+ file_name = "Mic"
295
+ subtitle, file_path = generate_file(
296
+ output_dir=self.output_dir,
297
+ output_file_name=file_name,
298
+ output_format=file_format,
299
+ result=transcribed_segments,
300
+ add_timestamp=add_timestamp
301
  )
302
 
303
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
304
+ return [result_str, file_path]
305
  except Exception as e:
306
  print(f"Error transcribing file: {e}")
307
  finally:
 
352
  progress(1, desc="Completed!")
353
 
354
  file_name = safe_filename(yt.title)
355
+ subtitle, file_path = generate_file(
356
+ output_dir=self.output_dir,
357
+ output_file_name=file_name,
358
+ output_format=file_format,
359
+ result=transcribed_segments,
360
+ add_timestamp=add_timestamp
361
  )
362
+
363
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
364
 
365
  if os.path.exists(audio):
366
  os.remove(audio)
367
 
368
+ return [result_str, file_path]
369
 
370
  except Exception as e:
371
  print(f"Error transcribing file: {e}")
 
386
  else:
387
  return list(ctranslate2.get_supported_compute_types("cpu"))
388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  @staticmethod
390
  def format_time(elapsed_time: float) -> str:
391
  """