LukeJacob2023 committed · verified
Commit aa0d499 · 1 parent: 80f134b

Update src/transcription_utils.py

Files changed (1): src/transcription_utils.py (+175 -175)
src/transcription_utils.py CHANGED
@@ -1,176 +1,176 @@
- import whisperx
- import json
- import os
- import torch
- import mimetypes
- import shutil
-
- # Define language options
- language_options = {
-     "Identify": None,
-     "English": "en", "Spanish": "es", "Chinese": "zh", "Hindi": "hi", "Arabic": "ar",
-     "Portuguese": "pt", "Bengali": "bn", "Russian": "ru", "Japanese": "ja", "Punjabi": "pa",
-     "German": "de", "Javanese": "jv", "Wu Chinese": "zh", "Malay": "ms", "Telugu": "te",
-     "Vietnamese": "vi", "Korean": "ko", "French": "fr", "Marathi": "mr", "Turkish": "tr"
- }
-
- # Available models for transcription
- model_options = {
-     "Large-v2": "large-v2",
-     "Medium": "medium",
-     "Small": "small",
-     "Base": "base"
- }
-
- # Initializes the ModelManager by setting default values and loading a model based on system capabilities (CUDA availability).
- class ModelManager:
-     def __init__(self):
-         self.current_model = None
-         self.current_model_name = None
-         self.current_device = None
-         if torch.cuda.is_available():
-             default_device = "cuda"
-             default_model = "Large-v2"
-         else:
-             default_device = "cpu"
-             default_model = "Medium"
-         self.load_model(default_model, default_device)
-
-     def load_model(self, model_choice, device):
-         if self.current_model is None or model_choice != self.current_model_name or device != self.current_device:
-             print(f"Attempting to load model: {model_choice} on device: {device}")
-             compute_type = "float32" if device == "cpu" else "float16"
-             self.current_model = whisperx.load_model(model_options[model_choice], device, compute_type=compute_type)
-             self.current_model_name = model_choice
-             self.current_device = device
-         else:
-             print(f"Using already loaded model: {self.current_model_name} on device: {self.current_device}")
-         return self.current_model
-
- # Validates if the given file path corresponds to a multimedia file (audio or video) by checking MIME types and specific file extensions.
- def validate_multimedia_file(file_path):
-     file_path = os.path.normpath(file_path)
-     mime_type, _ = mimetypes.guess_type(file_path)
-     if mime_type and (mime_type.startswith('audio') or mime_type.startswith('video')):
-         return file_path
-     else:
-         if file_path.lower().endswith(('.mp3', '.mp4', '.wav', '.avi', '.mov', '.flv')):
-             return file_path
-         else:
-             raise ValueError("The uploaded file is not a multimedia file. Please upload an appropriate audio or video file.")
-
- # Transcribes a multimedia file
- def transcribe(file_obj, device, language, model_choice, model_manager):
-     """
-     Transcribes a multimedia file using a specified model, handling file operations,
-     language identification, and transcription alignment, and outputs transcription in multiple formats.
-     """
-     _, ext = os.path.splitext(file_obj.name)
-     temp_dir = os.path.join(os.getcwd(), 'Temp')
-
-     if not os.path.exists(temp_dir):
-         os.makedirs(temp_dir)
-     new_file_path = os.path.join(temp_dir, f'resource{ext}')
-
-     shutil.copy(file_obj.name, new_file_path)
-
-     model = model_manager.load_model(model_choice, device)
-
-     validated_file_path = validate_multimedia_file(new_file_path)
-     audio = whisperx.load_audio(validated_file_path)
-
-     if language == "Identify":
-         result = model.transcribe(audio)
-         language_code = result["language"]
-     else:
-         language_code = language_options[language]
-         result = model.transcribe(audio, language=language_code)
-
-     model_a, metadata = whisperx.load_align_model(language_code=language_code, device=device)
-     try:
-         aligned_segments = []
-         for segment in result["segments"]:
-             aligned_segment = whisperx.align([segment], model_a, metadata, audio, device, return_char_alignments=False)
-             aligned_segments.extend(aligned_segment["segments"])
-     except Exception as e:
-         print(f"Error during alignment: {e}")
-         return None
-
-     segments_output = {"segments": aligned_segments}
-     json_output = json.dumps(segments_output, ensure_ascii=False, indent=4)
-     json_file_path = download_json_interface(json_output, temp_dir)
-     txt_path = save_as_text(aligned_segments, temp_dir)
-     vtt_path = save_as_vtt(aligned_segments, temp_dir)
-     srt_path = save_as_srt(aligned_segments, temp_dir)
-     return json_file_path, txt_path, vtt_path, srt_path
-
- # Saves the transcription text of audio segments to a file in the specified temporary directory and returns the file path.
- def save_as_text(segments, temp_dir):
-     txt_file_path = os.path.join(temp_dir, 'transcription_output.txt')
-     with open(txt_file_path, 'w', encoding='utf-8') as txt_file:
-         for segment in segments:
-             txt_file.write(f"{segment['text'].strip()}\n")
-     return txt_file_path
-
-
- def save_as_vtt(segments, temp_dir):
-     """
-     Saves the transcription text as a .vtt file (Web Video Text Tracks format),
-     which includes timestamps for each segment, in the specified temporary directory and returns the file path.
-     """
-     vtt_file_path = os.path.join(temp_dir, 'transcription_output.vtt')
-     with open(vtt_file_path, 'w', encoding='utf-8') as vtt_file:
-         vtt_file.write("WEBVTT\n\n")
-         for i, segment in enumerate(segments):
-             start = segment['start']
-             end = segment['end']
-             vtt_file.write(f"{i}\n")
-             vtt_file.write(f"{format_time(start)} --> {format_time(end)}\n")
-             vtt_file.write(f"{segment['text'].strip()}\n\n")
-     return vtt_file_path
-
- def download_json_interface(json_data, temp_dir):
-     """
-     Reads JSON-formatted transcription data, modifies and re-saves it in a neatly
-     formatted JSON file in the specified temporary directory, and returns the file path.
-     """
-     json_file_path = os.path.join(temp_dir, 'transcription_output.json')
-     with open(json_file_path, 'w', encoding='utf-8') as json_file:
-         json_data = json.loads(json_data)
-         for segment in json_data['segments']:
-             segment['text'] = segment['text'].strip()
-         json_data = json.dumps(json_data, ensure_ascii=False, indent=4)
-         json_file.write(json_data)
-     return json_file_path
-
-
- def save_as_srt(segments, temp_dir):
-     """
-     Saves the transcription text as an .srt file (SubRip Subtitle format),
-     which includes numbered entries with start and end times and corresponding text for each segment,
-     in the specified temporary directory and returns the file path.
-     """
-     srt_file_path = os.path.join(temp_dir, 'transcription_output.srt')
-     with open(srt_file_path, 'w', encoding='utf-8') as srt_file:
-         for i, segment in enumerate(segments):
-             start = segment['start']
-             end = segment['end']
-             srt_file.write(f"{i+1}\n")
-             srt_file.write(f"{format_time_srt(start)} --> {format_time_srt(end)}\n")
-             srt_file.write(f"{segment['text'].strip()}\n\n")
-     return srt_file_path
-
- # Converts a time value in seconds to a formatted string in the "hours:minutes:seconds.milliseconds" format, used for timestamps in VTT files.
- def format_time(time_in_seconds):
-     hours = int(time_in_seconds // 3600)
-     minutes = int((time_in_seconds % 3600) // 60)
-     seconds = time_in_seconds % 60
-     return f"{hours:02}:{minutes:02}:{seconds:06.3f}"
-
- # Converts a time value in seconds to a formatted string suitable for SRT files, specifically in the "hours:minutes:seconds,milliseconds" format.
- def format_time_srt(time_in_seconds):
-     hours = int(time_in_seconds // 3600)
-     minutes = int((time_in_seconds % 3600) // 60)
-     seconds = int(time_in_seconds % 60)
-     milliseconds = int((time_in_seconds - int(time_in_seconds)) * 1000)
      return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
 
+ import whisperx
+ import json
+ import os
+ import torch
+ import mimetypes
+ import shutil
+
+ # Define language options
+ language_options = {
+     "Identify": None,
+     "English": "en", "Spanish": "es", "Chinese": "zh", "Hindi": "hi", "Arabic": "ar",
+     "Portuguese": "pt", "Bengali": "bn", "Russian": "ru", "Japanese": "ja", "Punjabi": "pa",
+     "German": "de", "Javanese": "jv", "Wu Chinese": "zh", "Malay": "ms", "Telugu": "te",
+     "Vietnamese": "vi", "Korean": "ko", "French": "fr", "Marathi": "mr", "Turkish": "tr"
+ }
+
+ # Available models for transcription
+ model_options = {
+     "Large-v3": "large-v3",
+     "Medium": "medium",
+     "Small": "small",
+     "Base": "base"
+ }
+
+ # Initializes the ModelManager by setting default values and loading a model based on system capabilities (CUDA availability).
+ class ModelManager:
+     def __init__(self):
+         self.current_model = None
+         self.current_model_name = None
+         self.current_device = None
+         if torch.cuda.is_available():
+             default_device = "cuda"
+             default_model = "Large-v3"
+         else:
+             default_device = "cpu"
+             default_model = "Small"
+         self.load_model(default_model, default_device)
+
+     def load_model(self, model_choice, device):
+         if self.current_model is None or model_choice != self.current_model_name or device != self.current_device:
+             print(f"Attempting to load model: {model_choice} on device: {device}")
+             compute_type = "float32" if device == "cpu" else "float16"
+             self.current_model = whisperx.load_model(model_options[model_choice], device, compute_type=compute_type)
+             self.current_model_name = model_choice
+             self.current_device = device
+         else:
+             print(f"Using already loaded model: {self.current_model_name} on device: {self.current_device}")
+         return self.current_model
+
+ # Validates if the given file path corresponds to a multimedia file (audio or video) by checking MIME types and specific file extensions.
+ def validate_multimedia_file(file_path):
+     file_path = os.path.normpath(file_path)
+     mime_type, _ = mimetypes.guess_type(file_path)
+     if mime_type and (mime_type.startswith('audio') or mime_type.startswith('video')):
+         return file_path
+     else:
+         if file_path.lower().endswith(('.mp3', '.mp4', '.wav', '.avi', '.mov', '.flv')):
+             return file_path
+         else:
+             raise ValueError("The uploaded file is not a multimedia file. Please upload an appropriate audio or video file.")
+
+ # Transcribes a multimedia file
+ def transcribe(file_obj, device, language, model_choice, model_manager):
+     """
+     Transcribes a multimedia file using a specified model, handling file operations,
+     language identification, and transcription alignment, and outputs transcription in multiple formats.
+     """
+     _, ext = os.path.splitext(file_obj.name)
+     temp_dir = os.path.join(os.getcwd(), 'Temp')
+
+     if not os.path.exists(temp_dir):
+         os.makedirs(temp_dir)
+     new_file_path = os.path.join(temp_dir, f'resource{ext}')
+
+     shutil.copy(file_obj.name, new_file_path)
+
+     model = model_manager.load_model(model_choice, device)
+
+     validated_file_path = validate_multimedia_file(new_file_path)
+     audio = whisperx.load_audio(validated_file_path)
+
+     if language == "Identify":
+         result = model.transcribe(audio, batch_size=16)
+         language_code = result["language"]
+     else:
+         language_code = language_options[language]
+         result = model.transcribe(audio, language=language_code, batch_size=16)
+
+     model_a, metadata = whisperx.load_align_model(language_code=language_code, device=device)
+     try:
+         aligned_segments = []
+         for segment in result["segments"]:
+             aligned_segment = whisperx.align([segment], model_a, metadata, audio, device, return_char_alignments=False)
+             aligned_segments.extend(aligned_segment["segments"])
+     except Exception as e:
+         print(f"Error during alignment: {e}")
+         return None
+
+     segments_output = {"segments": aligned_segments}
+     json_output = json.dumps(segments_output, ensure_ascii=False, indent=4)
+     json_file_path = download_json_interface(json_output, temp_dir)
+     txt_path = save_as_text(aligned_segments, temp_dir)
+     vtt_path = save_as_vtt(aligned_segments, temp_dir)
+     srt_path = save_as_srt(aligned_segments, temp_dir)
+     return json_file_path, txt_path, vtt_path, srt_path
+
+ # Saves the transcription text of audio segments to a file in the specified temporary directory and returns the file path.
+ def save_as_text(segments, temp_dir):
+     txt_file_path = os.path.join(temp_dir, 'transcription_output.txt')
+     with open(txt_file_path, 'w', encoding='utf-8') as txt_file:
+         for segment in segments:
+             txt_file.write(f"{segment['text'].strip()}\n")
+     return txt_file_path
+
+
+ def save_as_vtt(segments, temp_dir):
+     """
+     Saves the transcription text as a .vtt file (Web Video Text Tracks format),
+     which includes timestamps for each segment, in the specified temporary directory and returns the file path.
+     """
+     vtt_file_path = os.path.join(temp_dir, 'transcription_output.vtt')
+     with open(vtt_file_path, 'w', encoding='utf-8') as vtt_file:
+         vtt_file.write("WEBVTT\n\n")
+         for i, segment in enumerate(segments):
+             start = segment['start']
+             end = segment['end']
+             vtt_file.write(f"{i}\n")
+             vtt_file.write(f"{format_time(start)} --> {format_time(end)}\n")
+             vtt_file.write(f"{segment['text'].strip()}\n\n")
+     return vtt_file_path
+
+ def download_json_interface(json_data, temp_dir):
+     """
+     Reads JSON-formatted transcription data, modifies and re-saves it in a neatly
+     formatted JSON file in the specified temporary directory, and returns the file path.
+     """
+     json_file_path = os.path.join(temp_dir, 'transcription_output.json')
+     with open(json_file_path, 'w', encoding='utf-8') as json_file:
+         json_data = json.loads(json_data)
+         for segment in json_data['segments']:
+             segment['text'] = segment['text'].strip()
+         json_data = json.dumps(json_data, ensure_ascii=False, indent=4)
+         json_file.write(json_data)
+     return json_file_path
+
+
+ def save_as_srt(segments, temp_dir):
+     """
+     Saves the transcription text as an .srt file (SubRip Subtitle format),
+     which includes numbered entries with start and end times and corresponding text for each segment,
+     in the specified temporary directory and returns the file path.
+     """
+     srt_file_path = os.path.join(temp_dir, 'transcription_output.srt')
+     with open(srt_file_path, 'w', encoding='utf-8') as srt_file:
+         for i, segment in enumerate(segments):
+             start = segment['start']
+             end = segment['end']
+             srt_file.write(f"{i+1}\n")
+             srt_file.write(f"{format_time_srt(start)} --> {format_time_srt(end)}\n")
+             srt_file.write(f"{segment['text'].strip()}\n\n")
+     return srt_file_path
+
+ # Converts a time value in seconds to a formatted string in the "hours:minutes:seconds.milliseconds" format, used for timestamps in VTT files.
+ def format_time(time_in_seconds):
+     hours = int(time_in_seconds // 3600)
+     minutes = int((time_in_seconds % 3600) // 60)
+     seconds = time_in_seconds % 60
+     return f"{hours:02}:{minutes:02}:{seconds:06.3f}"
+
+ # Converts a time value in seconds to a formatted string suitable for SRT files, specifically in the "hours:minutes:seconds,milliseconds" format.
+ def format_time_srt(time_in_seconds):
+     hours = int(time_in_seconds // 3600)
+     minutes = int((time_in_seconds % 3600) // 60)
+     seconds = int(time_in_seconds % 60)
+     milliseconds = int((time_in_seconds - int(time_in_seconds)) * 1000)
      return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
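
For reference, a minimal usage sketch (not part of the commit) showing how these utilities can be driven end to end. It assumes `src/` is on the import path and that a local `sample.mp3` exists; the `SimpleNamespace` object is a hypothetical stand-in for whatever upload object the app passes in, since `transcribe()` only reads its `.name` attribute.

    from types import SimpleNamespace

    from transcription_utils import ModelManager, transcribe  # assumes src/ is importable

    # Loads a default model up front: large-v3 on CUDA, small on CPU.
    manager = ModelManager()

    # Hypothetical input: transcribe() only uses file_obj.name, so any object
    # with a .name attribute pointing at a real audio/video file will do.
    file_obj = SimpleNamespace(name="sample.mp3")

    outputs = transcribe(file_obj, manager.current_device, "Identify", "Small", manager)

    # transcribe() returns None when alignment fails, so check before unpacking.
    if outputs is not None:
        json_path, txt_path, vtt_path, srt_path = outputs
        print("Wrote:", json_path, txt_path, vtt_path, srt_path)
    else:
        print("Alignment failed; no output files were written.")

Note that `format_time` produces the dot-separated `HH:MM:SS.mmm` cues VTT expects, while `format_time_srt` produces the comma-separated `HH:MM:SS,mmm` form SRT requires, which is why the two subtitle writers cannot share a single formatter.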