File size: 7,850 Bytes
aa0d499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8f59b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import whisperx
import json
import os
import torch
import mimetypes
import shutil

# Define language options
# Maps UI display names to the language codes passed to whisperx; the
# "Identify" entry (None) means auto-detect the language from the audio.
# NOTE(review): "Chinese" and "Wu Chinese" both map to "zh" — whisper has no
# separate Wu code, so both menu choices behave identically.
language_options = {
    "Identify": None,
    "English": "en", "Spanish": "es", "Chinese": "zh", "Hindi": "hi", "Arabic": "ar",
    "Portuguese": "pt", "Bengali": "bn", "Russian": "ru", "Japanese": "ja", "Punjabi": "pa",
    "German": "de", "Javanese": "jv", "Wu Chinese": "zh", "Malay": "ms", "Telugu": "te",
    "Vietnamese": "vi", "Korean": "ko", "French": "fr", "Marathi": "mr", "Turkish": "tr"
}

# Available models for transcription
# Maps UI display names to the model identifiers handed to
# whisperx.load_model (see ModelManager.load_model).
model_options = {
    "Large-v3": "large-v3",
    "Medium": "medium",
    "Small": "small",
    "Base": "base"
}

# Owns a single cached whisperx model; on construction it picks defaults based
# on system capabilities (CUDA availability) and loads the initial model.
class ModelManager:
    """Caches one whisperx model at a time.

    A model is (re)loaded only when the requested model name or device
    differs from what is currently held; otherwise the cached instance
    is returned unchanged.
    """

    def __init__(self):
        # Nothing cached yet.
        self.current_model = None
        self.current_model_name = None
        self.current_device = None
        # Choose defaults from the hardware: GPU gets the big model,
        # CPU falls back to a small one.
        has_cuda = torch.cuda.is_available()
        default_device = "cuda" if has_cuda else "cpu"
        default_model = "Large-v3" if has_cuda else "Small"
        self.load_model(default_model, default_device)

    def load_model(self, model_choice, device):
        """Return a whisperx model for (model_choice, device), reusing the cached one on a match."""
        cache_hit = (
            self.current_model is not None
            and model_choice == self.current_model_name
            and device == self.current_device
        )
        if cache_hit:
            print(f"Using already loaded model: {self.current_model_name} on device: {self.current_device}")
        else:
            print(f"Attempting to load model: {model_choice} on device: {device}")
            # float16 is not supported on CPU, so use float32 there.
            compute_type = "float16" if device != "cpu" else "float32"
            self.current_model = whisperx.load_model(model_options[model_choice], device, compute_type=compute_type)
            self.current_model_name = model_choice
            self.current_device = device
        return self.current_model

# Validates if the given file path corresponds to a multimedia file (audio or video) by checking MIME types and specific file extensions.
def validate_multimedia_file(file_path):
    """Return the normalized path if it looks like an audio/video file.

    A file is accepted when its guessed MIME type is audio/* or video/*,
    or when it carries one of a short whitelist of common multimedia
    extensions. Raises ValueError otherwise.
    """
    file_path = os.path.normpath(file_path)
    guessed, _ = mimetypes.guess_type(file_path)
    if guessed and guessed.split('/')[0] in ('audio', 'video'):
        return file_path
    # Fallback for extensions mimetypes may not know about.
    if file_path.lower().endswith(('.mp3', '.mp4', '.wav', '.avi', '.mov', '.flv')):
        return file_path
    raise ValueError("The uploaded file is not a multimedia file. Please upload an appropriate audio or video file.")

# Transcribes a multimedia file
def transcribe(file_obj, device, language, model_choice, model_manager):
    """
    Transcribe a multimedia file with whisperx and write results in four formats.

    Parameters:
        file_obj: uploaded file object; only its ``.name`` path attribute is
            read (Gradio-style upload object — TODO confirm against caller).
        device: "cuda" or "cpu", forwarded to the model loader and aligner.
        language: key into ``language_options``; "Identify" lets the model
            detect the language from the audio.
        model_choice: key into ``model_options``.
        model_manager: ModelManager providing the (cached) transcription model.

    Returns:
        Tuple (json_path, txt_path, vtt_path, srt_path) of output file paths,
        or None when alignment fails.
    """
    _, ext = os.path.splitext(file_obj.name)
    temp_dir = os.path.join(os.getcwd(), 'Temp')
    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(temp_dir, exist_ok=True)
    new_file_path = os.path.join(temp_dir, f'resource{ext}')

    # Copy the upload to a stable local path before handing it to whisperx.
    shutil.copy(file_obj.name, new_file_path)

    model = model_manager.load_model(model_choice, device)

    validated_file_path = validate_multimedia_file(new_file_path)
    audio = whisperx.load_audio(validated_file_path)

    if language == "Identify":
        # Let whisperx detect the language; it reports the detected code.
        result = model.transcribe(audio, batch_size=16)
        language_code = result["language"]
    else:
        language_code = language_options[language]
        result = model.transcribe(audio, language=language_code, batch_size=16)

    model_a, metadata = whisperx.load_align_model(language_code=language_code, device=device)
    try:
        # Align one segment at a time; a failure anywhere aborts alignment.
        aligned_segments = []
        for segment in result["segments"]:
            aligned_segment = whisperx.align([segment], model_a, metadata, audio, device, return_char_alignments=False)
            aligned_segments.extend(aligned_segment["segments"])
    except Exception as e:
        print(f"Error during alignment: {e}")
        return None

    segments_output = {"segments": aligned_segments}
    json_output = json.dumps(segments_output, ensure_ascii=False, indent=4)
    json_file_path = download_json_interface(json_output, temp_dir)
    txt_path = save_as_text(aligned_segments, temp_dir)
    vtt_path = save_as_vtt(aligned_segments, temp_dir)
    srt_path = save_as_srt(aligned_segments, temp_dir)
    return json_file_path, txt_path, vtt_path, srt_path

# Saves the transcription text of audio segments to a file in the specified temporary directory and returns the file path.
def save_as_text(segments, temp_dir):
    """Write one stripped line of text per segment into temp_dir and return the file path."""
    txt_file_path = os.path.join(temp_dir, 'transcription_output.txt')
    lines = [f"{seg['text'].strip()}\n" for seg in segments]
    with open(txt_file_path, 'w', encoding='utf-8') as txt_file:
        txt_file.writelines(lines)
    return txt_file_path


def save_as_vtt(segments, temp_dir):
    """
    Write segments as a WebVTT file and return the file path.

    Output is the "WEBVTT" header followed by one cue per segment:
    a 0-based cue identifier, a start --> end timestamp line, and the
    stripped segment text, each cue separated by a blank line.
    """
    vtt_file_path = os.path.join(temp_dir, 'transcription_output.vtt')
    with open(vtt_file_path, 'w', encoding='utf-8') as vtt_file:
        vtt_file.write("WEBVTT\n\n")
        for index, segment in enumerate(segments):
            stamp = f"{format_time(segment['start'])} --> {format_time(segment['end'])}"
            body = segment['text'].strip()
            vtt_file.write(f"{index}\n{stamp}\n{body}\n\n")
    return vtt_file_path

def download_json_interface(json_data, temp_dir):
    """
    Re-save JSON transcription data with whitespace-stripped segment text.

    Parses the given JSON string, strips each segment's 'text', writes the
    pretty-printed result into temp_dir, and returns the file path.
    """
    json_file_path = os.path.join(temp_dir, 'transcription_output.json')
    parsed = json.loads(json_data)
    for segment in parsed['segments']:
        segment['text'] = segment['text'].strip()
    pretty = json.dumps(parsed, ensure_ascii=False, indent=4)
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json_file.write(pretty)
    return json_file_path


def save_as_srt(segments, temp_dir):
    """
    Write segments as a SubRip (.srt) file and return the file path.

    Each entry is a 1-based sequence number, a start --> end timestamp
    line, and the stripped segment text, followed by a blank line.
    """
    srt_file_path = os.path.join(temp_dir, 'transcription_output.srt')
    with open(srt_file_path, 'w', encoding='utf-8') as srt_file:
        for number, segment in enumerate(segments, start=1):
            times = f"{format_time_srt(segment['start'])} --> {format_time_srt(segment['end'])}"
            srt_file.write(f"{number}\n{times}\n{segment['text'].strip()}\n\n")
    return srt_file_path

# Converts a time value in seconds to a formatted "hours:minutes:seconds.milliseconds" string (period separator, as WebVTT requires), used for timestamps in VTT files.
def format_time(time_in_seconds):
    """Format seconds as an HH:MM:SS.mmm WebVTT timestamp (period before milliseconds)."""
    total_minutes, seconds = divmod(time_in_seconds, 60)
    hours, minutes = divmod(int(total_minutes), 60)
    # seconds keeps its fractional part; 06.3f pads to SS.mmm.
    return f"{hours:02}:{minutes:02}:{seconds:06.3f}"

# Converts a time value in seconds to a formatted string suitable for SRT files, specifically in the "hours:minutes:seconds,milliseconds" format.
def format_time_srt(time_in_seconds):
    """Format seconds as an HH:MM:SS,mmm SubRip timestamp (comma before milliseconds)."""
    whole_seconds = int(time_in_seconds)
    # Milliseconds are truncated (not rounded), matching int() semantics.
    milliseconds = int((time_in_seconds - whole_seconds) * 1000)
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"