import os
from pathlib import Path
from typing import Optional

import whisperx
import whisper  # used only as a fallback for language detection

import whisper_utils
import subtitle_utils
from utils import time_task

def transcribe_audio(model: whisperx.asr.WhisperModel, audio_path: Path, srt_path: Path,
                     lang: Optional[str] = None, device: str = "cpu", batch_size: int = 4):
    """Transcribe an audio file with WhisperX, align it when possible, and save the result as SRT."""
    audio = whisperx.load_audio(file=audio_path.as_posix(), sr=model.model.feature_extractor.sampling_rate)
    # Progress callback for WhisperX. It may be invoked as (state),
    # (current, total) or (state, current, total), so normalize the
    # arguments before formatting the status line.
    def progress_callback(state=None, current: Optional[int] = None, total: Optional[int] = None):
        args = [arg for arg in (state, current, total) if arg is not None]

        if len(args) == 1:
            state = args[0]
        elif len(args) == 2:
            state, current, total = None, args[0], args[1]
        elif len(args) == 3:
            state, current, total = args

        # Resolve a printable label; enum-like states expose .value
        if state is None:
            state = "WhisperX"
        elif not isinstance(state, (str, int)):
            state = getattr(state, "value", "WhisperX")

        progress = f": {round(current / total * 100)}% [{current}/{total}]" if current and total else ""
        print(f"\r{' ' * 60}\r{state}{progress}", end=' ', flush=True)

    # Transcribe
    with time_task("Running WhisperX transcription engine...", end='\n'):
        transcribe = model.transcribe(audio=audio, language=lang, batch_size=batch_size, on_progress=progress_callback)

    # Align if an alignment model is available for the language
    if lang in whisperx.alignment.DEFAULT_ALIGN_MODELS_HF or lang in whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH:
        with time_task(message_start="Running alignment...", end='\n'):
            try:
                model_a, metadata = whisperx.load_align_model(language_code=lang, device=device)
                transcribe = whisperx.align(transcript=transcribe["segments"], model=model_a, align_model_metadata=metadata, audio=audio, device=device, return_char_alignments=True, on_progress=progress_callback)
            except Exception:
                # Retry on CPU: alignment occasionally fails on the GPU
                model_a, metadata = whisperx.load_align_model(language_code=lang, device="cpu")
                transcribe = whisperx.align(transcript=transcribe["segments"], model=model_a, align_model_metadata=metadata, audio=audio, device="cpu", return_char_alignments=True, on_progress=progress_callback)
    else:
        print(f"Language {lang} not supported for alignment. Skipping this step.")

    # Format subtitles
    segments = subtitle_utils.format_segments(transcribe['segments'])

    # Save the subtitle file
    subtitle_utils.SaveSegmentsToSrt(segments, srt_path)

    return transcribe


def detect_language(model: whisperx.asr.WhisperModel, audio_path: Path):
    """Detect the spoken language, falling back to the openai-whisper base model on failure."""
    try:
        if os.getenv("COLAB_RELEASE_TAG"):
            raise RuntimeError("Method invalid for Google Colab")
        audio = whisperx.load_audio(audio_path.as_posix(), model.model.feature_extractor.sampling_rate)
        audio = whisper.pad_or_trim(audio, model.model.feature_extractor.n_samples)
        mel = whisperx.asr.log_mel_spectrogram(audio, n_mels=model.model.model.n_mels)
        encoder_output = model.model.encode(mel)
        results = model.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        return language_token[2:-2]  # strip the "<|" and "|>" from tokens like "<|en|>"
    except Exception:
        # Fall back to the reference openai-whisper base model on CPU
        print("using whisper base model for detection: ", end='')
        whisper_model = whisper.load_model("base", device="cpu", in_memory=True)
        return whisper_utils.detect_language(model=whisper_model, audio_path=audio_path)
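

# Usage sketch (a minimal example; "audio.wav" is a hypothetical input file, and it
# assumes the installed whisperx.load_model accepts these arguments):
#
#     model = whisperx.load_model("large-v2", device="cpu", compute_type="int8")
#     lang = detect_language(model, Path("audio.wav"))
#     transcribe_audio(model, Path("audio.wav"), Path("audio.srt"), lang=lang, device="cpu")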