Adding CLI
This is similar to the CLI in Whisper, but it also supports downloading URLs (including playlists) and using a VAD.
- app.py +64 -49
- cli.py +108 -0
- src/download.py +14 -7
- src/vad.py +3 -1

app.py
CHANGED

@@ -53,7 +53,7 @@ class WhisperTranscriber:
         self.inputAudioMaxDuration = inputAudioMaxDuration
         self.deleteUploadedFiles = deleteUploadedFiles

-    def 
+    def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
         try:
             source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

@@ -67,54 +67,14 @@ class WhisperTranscriber:
                 model = whisper.load_model(selectedModel)
                 self.model_cache[selectedModel] = model

-                # 
-
-
-                # 
-                if (vad == 'silero-vad'):
-                    # Use Silero VAD and include gaps
-                    if (self.vad_model is None):
-                        self.vad_model = VadSileroTranscription()
-
-                    process_gaps = VadSileroTranscription(transcribe_non_speech = True, 
-                                    max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize, 
-                                    segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.vad_model)
-                    result = process_gaps.transcribe(source, whisperCallable)
-                elif (vad == 'silero-vad-skip-gaps'):
-                    # Use Silero VAD 
-                    if (self.vad_model is None):
-                        self.vad_model = VadSileroTranscription()
-
-                    skip_gaps = VadSileroTranscription(transcribe_non_speech = False, 
-                                    max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize, 
-                                    segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.vad_model)
-                    result = skip_gaps.transcribe(source, whisperCallable)
-                elif (vad == 'periodic-vad'):
-                    # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
-                    # it may create a break in the middle of a sentence, causing some artifacts.
-                    periodic_vad = VadPeriodicTranscription(periodic_duration=vadMaxMergeSize)
-                    result = periodic_vad.transcribe(source, whisperCallable)
-                else:
-                    # Default VAD
-                    result = whisperCallable(source)
-
-                text = result["text"]
-
-                language = result["language"]
-                languageMaxLineWidth = self.__get_max_line_width(language)
-
-                print("Max line width " + str(languageMaxLineWidth))
-                vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
-                srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
-
-                # Files that can be downloaded
+                # Execute whisper
+                result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding)
+
+                # Write result
                 downloadDirectory = tempfile.mkdtemp()
+
                 filePrefix = slugify(sourceName, allow_unicode=True)
-
-                download = []
-                download.append(self.__create_file(srt, downloadDirectory, filePrefix + "-subs.srt"));
-                download.append(self.__create_file(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
-                download.append(self.__create_file(text, downloadDirectory, filePrefix + "-transcript.txt"));
+                download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)

                 return download, text, vtt

@@ -127,13 +87,68 @@ class WhisperTranscriber:
         except ExceededMaximumDuration as e:
             return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

+    def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None, 
+                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, **decodeOptions: dict):
+        # Callable for processing an audio file
+        whisperCallable = lambda audio : model.transcribe(audio, language=language, task=task, **decodeOptions)
+
+        # The results
+        if (vad == 'silero-vad'):
+            # Use Silero VAD and include gaps
+            if (self.vad_model is None):
+                self.vad_model = VadSileroTranscription()
+
+            process_gaps = VadSileroTranscription(transcribe_non_speech = True, 
+                            max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize, 
+                            segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.vad_model)
+            result = process_gaps.transcribe(audio_path, whisperCallable)
+        elif (vad == 'silero-vad-skip-gaps'):
+            # Use Silero VAD 
+            if (self.vad_model is None):
+                self.vad_model = VadSileroTranscription()
+
+            skip_gaps = VadSileroTranscription(transcribe_non_speech = False, 
+                            max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize, 
+                            segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.vad_model)
+            result = skip_gaps.transcribe(audio_path, whisperCallable)
+        elif (vad == 'periodic-vad'):
+            # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
+            # it may create a break in the middle of a sentence, causing some artifacts.
+            periodic_vad = VadPeriodicTranscription(periodic_duration=vadMaxMergeSize)
+            result = periodic_vad.transcribe(audio_path, whisperCallable)
+        else:
+            # Default VAD
+            result = whisperCallable(audio_path)
+
+        return result
+
+    def write_result(self, result: dict, source_name: str, output_dir: str):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        text = result["text"]
+        language = result["language"]
+        languageMaxLineWidth = self.__get_max_line_width(language)
+
+        print("Max line width " + str(languageMaxLineWidth))
+        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
+        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
+
+        output_files = []
+        output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
+        output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
+        output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
+
+        return output_files, text, vtt
+
     def clear_cache(self):
         self.model_cache = dict()
+        self.vad_model = None

     def __get_source(self, urlData, uploadFile, microphoneData):
         if urlData:
             # Download from YouTube
-            source = download_url(urlData, self.inputAudioMaxDuration)
+            source = download_url(urlData, self.inputAudioMaxDuration)[0]
         else:
             # File input
             source = uploadFile if uploadFile is not None else microphoneData

@@ -194,7 +209,7 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):

     ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"

-    demo = gr.Interface(fn=ui.
+    demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
         gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
         gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         gr.Text(label="URL (YouTube, etc.)"),
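
With this refactoring the web UI handler becomes a thin wrapper: transcribe_file runs Whisper (optionally behind a VAD), and write_result renders the SRT/VTT/text outputs. A minimal sketch of using the two methods directly, based only on the signatures above; the model size, audio path and output directory are placeholders:

import whisper

from app import WhisperTranscriber

# Assumes the repository root is on the path so that app.py and src/ are importable.
model = whisper.load_model("medium")
transcriber = WhisperTranscriber(deleteUploadedFiles=False)

# vad may be 'silero-vad', 'silero-vad-skip-gaps', 'periodic-vad', or None for plain Whisper.
result = transcriber.transcribe_file(model, "recording.mp3", language=None, task="transcribe",
                                     vad="silero-vad", vadMergeWindow=5, vadMaxMergeSize=150, vadPadding=1)

# Renders the subtitles and transcript, writes them to the output directory,
# and returns the created file paths together with the plain text and VTT.
output_files, text, vtt = transcriber.write_result(result, "recording", "./output")
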
    	
cli.py
ADDED

@@ -0,0 +1,108 @@
+import argparse
+import os
+import pathlib
+from urllib.parse import urlparse
+import warnings
+import numpy as np
+
+import whisper
+
+import torch
+from app import LANGUAGES, WhisperTranscriber
+from src.download import download_url
+
+from src.utils import optional_float, optional_int, str2bool
+
+
+def cli():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
+    parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large"], help="name of the Whisper model to use")
+    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
+    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+
+    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
+
+    parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
+    parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
+    parser.add_argument("--vad_max_merge_size", type=optional_float, default=150, help="The maximum size (in seconds) of a voice segment")
+    parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
+
+    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
+    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
+    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
+    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+
+    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
+
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+
+    args = parser.parse_args().__dict__
+    model_name: str = args.pop("model")
+    model_dir: str = args.pop("model_dir")
+    output_dir: str = args.pop("output_dir")
+    device: str = args.pop("device")
+    os.makedirs(output_dir, exist_ok=True)
+
+    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
+        warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
+        args["language"] = "en"
+
+    temperature = args.pop("temperature")
+    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
+    if temperature_increment_on_fallback is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    else:
+        temperature = [temperature]
+
+    vad = args.pop("vad")
+    vad_merge_window = args.pop("vad_merge_window")
+    vad_max_merge_size = args.pop("vad_max_merge_size")
+    vad_padding = args.pop("vad_padding")
+
+    model = whisper.load_model(model_name, device=device, download_root=model_dir)
+    transcriber = WhisperTranscriber(deleteUploadedFiles=False)
+
+    for audio_path in args.pop("audio"):
+        sources = []
+
+        # Detect URL and download the audio
+        if (uri_validator(audio_path)):
+            # Download from YouTube/URL directly
+            for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
+                source_name = os.path.basename(source_path)
+                sources.append({ "path": source_path, "name": source_name })
+        else:
+            sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
+
+        for source in sources:
+            source_path = source["path"]
+            source_name = source["name"]
+
+            result = transcriber.transcribe_file(model, source_path, temperature=temperature, 
+                                                 vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size, 
+                                                 vadPadding=vad_padding, **args)
+
+            transcriber.write_result(result, source_name, output_dir)
+
+    transcriber.clear_cache()
+
+def uri_validator(x):
+    try:
+        result = urlparse(x)
+        return all([result.scheme, result.netloc])
+    except:
+        return False
+
+if __name__ == '__main__':
+    cli()
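
Two details of the argument handling above are easy to miss; the snippet below illustrates them with the parser defaults (the invocation and the URL are placeholders, not taken from the commit):

import numpy as np
from urllib.parse import urlparse

# Hypothetical invocation (flag names come from the parser above, the URL is a placeholder):
#   python cli.py --model medium --vad silero-vad --output_dir ./output "https://www.youtube.com/watch?v=XXXXXXXXXXX"

# 1. Temperature fallback: with --temperature 0 and --temperature_increment_on_fallback 0.2,
#    the schedule handed to Whisper becomes (0.0, 0.2, 0.4, 0.6, 0.8, 1.0).
schedule = tuple(np.arange(0.0, 1.0 + 1e-6, 0.2))
print(schedule)

# 2. URL detection: uri_validator (copied from cli.py above) treats anything with a scheme
#    and a host as a remote source to download; everything else is read as a local file.
def uri_validator(x):
    try:
        result = urlparse(x)
        return all([result.scheme, result.netloc])
    except:
        return False

print(uri_validator("https://www.youtube.com/watch?v=XXXXXXXXXXX"))  # True  -> downloaded first
print(uri_validator("recording.mp3"))                                # False -> used as-is
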
    	
src/download.py
CHANGED

@@ -1,4 +1,5 @@
 from tempfile import mkdtemp
+from typing import List
 from yt_dlp import YoutubeDL

 import yt_dlp

@@ -13,25 +14,28 @@ class FilenameCollectorPP(PostProcessor):
         self.filenames.append(information["filepath"])
         return [], information

-def download_url(url: str, maxDuration: int = None):
+def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
     try:
-        return _perform_download(url, maxDuration=maxDuration)
+        return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
     except yt_dlp.utils.DownloadError as e:
         # In case of an OS error, try again with a different output template
         if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
             return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
         pass

-def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None):
-    
+def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
+    # Create a temporary directory to store the downloaded files
+    if destinationDirectory is None:
+        destinationDirectory = mkdtemp()

     ydl_opts = {
         "format": "bestaudio/best",
-        'playlist_items': '1',
         'paths': {
             'home': destinationDirectory
         }
     }
+    if (playlistItems):
+        ydl_opts['playlist_items'] = playlistItems

     # Add output template if specified
     if outputTemplate:

@@ -53,8 +57,11 @@ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None):
     if len(filename_collector.filenames) <= 0:
         raise Exception("Cannot download " + url)

-    result = 
-
+    result = []
+
+    for filename in filename_collector.filenames:
+        result.append(filename)
+        print("Downloaded " + filename)

     return result 

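
The new parameters let callers choose the download directory and whether the whole playlist is fetched, and the function now returns a list of file paths. A sketch of both call styles based on the signature above; the URLs, duration limit and directory are placeholders:

from src.download import download_url

# Web UI style: the default playlistItems="1" restricts the download to the first
# playlist entry, and files land in a freshly created temporary directory.
files = download_url("https://www.youtube.com/watch?v=XXXXXXXXXXX", maxDuration=600)

# CLI style (mirrors the call in cli.py): playlistItems=None leaves yt-dlp's playlist
# selection unrestricted, and destinationDirectory keeps the audio next to the transcripts.
all_files = download_url("https://www.youtube.com/playlist?list=XXXXXXXXXX",
                         maxDuration=-1, destinationDirectory="./output", playlistItems=None)

for path in all_files:
    print(path)
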
    	
src/vad.py
CHANGED

@@ -188,7 +188,9 @@ class AbstractTranscription(ABC):

             result.append(current_segment)

-        
+        # Add last segment
+        last_segment = segments[-1]
+        result.append(last_segment)

         # Also include total duration if specified
         if (total_duration is not None):
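
The added lines address a common off-by-one in segment merging: if segments are only emitted when the loop moves on to a new one, the final segment never makes it into the result. A schematic of that general pattern (illustrative only, not the actual vad.py code):

def merge_segments(segments, max_gap):
    """Illustrative merge of {'start', 'end'} segments whose gap is at most max_gap."""
    result = []
    current = None

    for seg in segments:
        if current is None:
            current = dict(seg)
        elif seg["start"] - current["end"] <= max_gap:
            current["end"] = seg["end"]   # extend the current group
        else:
            result.append(current)        # emitted only when a new group starts...
            current = dict(seg)

    if current is not None:
        result.append(current)            # ...so the last group must be appended after the loop

    return result
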