Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Ensure progress bar works for multiple files
Browse files- app.py +23 -6
- src/source.py +22 -12
    	
        app.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ import numpy as np | |
| 12 |  | 
| 13 | 
             
            import torch
         | 
| 14 | 
             
            from src.config import ApplicationConfig
         | 
| 15 | 
            -
            from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
         | 
| 16 | 
             
            from src.modelCache import ModelCache
         | 
| 17 | 
             
            from src.source import get_audio_source_collection
         | 
| 18 | 
             
            from src.vadParallel import ParallelContext, ParallelTranscription
         | 
| @@ -135,9 +135,17 @@ class WhisperTranscriber: | |
| 135 |  | 
| 136 | 
             
                            outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
         | 
| 137 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 138 | 
             
                            # Execute whisper
         | 
| 139 | 
             
                            for source in sources:
         | 
| 140 | 
             
                                source_prefix = ""
         | 
|  | |
| 141 |  | 
| 142 | 
             
                                if (len(sources) > 1):
         | 
| 143 | 
             
                                    # Prefix (minimum 2 digits)
         | 
| @@ -145,10 +153,18 @@ class WhisperTranscriber: | |
| 145 | 
             
                                    source_prefix = str(source_index).zfill(2) + "_"
         | 
| 146 | 
             
                                    print("Transcribing ", source.source_path)
         | 
| 147 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 148 | 
             
                                # Transcribe
         | 
| 149 | 
            -
                                result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,  | 
| 150 | 
             
                                filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
         | 
| 151 |  | 
|  | |
|  | |
|  | |
| 152 | 
             
                                source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
         | 
| 153 |  | 
| 154 | 
             
                                if len(sources) > 1:
         | 
| @@ -209,19 +225,20 @@ class WhisperTranscriber: | |
| 209 |  | 
| 210 | 
             
                def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None, 
         | 
| 211 | 
             
                                    vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, 
         | 
| 212 | 
            -
                                     | 
| 213 |  | 
| 214 | 
             
                    initial_prompt = decodeOptions.pop('initial_prompt', None)
         | 
| 215 |  | 
|  | |
|  | |
|  | |
|  | |
| 216 | 
             
                    if ('task' in decodeOptions):
         | 
| 217 | 
             
                        task = decodeOptions.pop('task')
         | 
| 218 |  | 
| 219 | 
             
                    # Callable for processing an audio file
         | 
| 220 | 
             
                    whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
         | 
| 221 |  | 
| 222 | 
            -
                    # A listener that will report progress to Gradio
         | 
| 223 | 
            -
                    progressListener = self._create_progress_listener(progress)
         | 
| 224 | 
            -
             | 
| 225 | 
             
                    # The results
         | 
| 226 | 
             
                    if (vad == 'silero-vad'):
         | 
| 227 | 
             
                        # Silero VAD where non-speech gaps are transcribed
         | 
|  | |
| 12 |  | 
| 13 | 
             
            import torch
         | 
| 14 | 
             
            from src.config import ApplicationConfig
         | 
| 15 | 
            +
            from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
         | 
| 16 | 
             
            from src.modelCache import ModelCache
         | 
| 17 | 
             
            from src.source import get_audio_source_collection
         | 
| 18 | 
             
            from src.vadParallel import ParallelContext, ParallelTranscription
         | 
|  | |
| 135 |  | 
| 136 | 
             
                            outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
         | 
| 137 |  | 
| 138 | 
            +
                            # Progress
         | 
| 139 | 
            +
                            total_duration = sum([source.get_audio_duration() for source in sources])
         | 
| 140 | 
            +
                            current_progress = 0
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                            # A listener that will report progress to Gradio
         | 
| 143 | 
            +
                            root_progress_listener = self._create_progress_listener(progress)
         | 
| 144 | 
            +
             | 
| 145 | 
             
                            # Execute whisper
         | 
| 146 | 
             
                            for source in sources:
         | 
| 147 | 
             
                                source_prefix = ""
         | 
| 148 | 
            +
                                source_audio_duration = source.get_audio_duration()
         | 
| 149 |  | 
| 150 | 
             
                                if (len(sources) > 1):
         | 
| 151 | 
             
                                    # Prefix (minimum 2 digits)
         | 
|  | |
| 153 | 
             
                                    source_prefix = str(source_index).zfill(2) + "_"
         | 
| 154 | 
             
                                    print("Transcribing ", source.source_path)
         | 
| 155 |  | 
| 156 | 
            +
                                scaled_progress_listener = SubTaskProgressListener(root_progress_listener, 
         | 
| 157 | 
            +
                                                               base_task_total=total_duration,
         | 
| 158 | 
            +
                                                               sub_task_start=current_progress,
         | 
| 159 | 
            +
                                                               sub_task_total=source_audio_duration)
         | 
| 160 | 
            +
             | 
| 161 | 
             
                                # Transcribe
         | 
| 162 | 
            +
                                result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, scaled_progress_listener, **decodeOptions)
         | 
| 163 | 
             
                                filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
         | 
| 164 |  | 
| 165 | 
            +
                                # Update progress
         | 
| 166 | 
            +
                                current_progress += source_audio_duration
         | 
| 167 | 
            +
             | 
| 168 | 
             
                                source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
         | 
| 169 |  | 
| 170 | 
             
                                if len(sources) > 1:
         | 
|  | |
| 225 |  | 
| 226 | 
             
                def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None, 
         | 
| 227 | 
             
                                    vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, 
         | 
| 228 | 
            +
                                    progressListener: ProgressListener = None, **decodeOptions: dict):
         | 
| 229 |  | 
| 230 | 
             
                    initial_prompt = decodeOptions.pop('initial_prompt', None)
         | 
| 231 |  | 
| 232 | 
            +
                    if progressListener is None:
         | 
| 233 | 
            +
                        # Default progress listener
         | 
| 234 | 
            +
                        progressListener = ProgressListener()
         | 
| 235 | 
            +
             | 
| 236 | 
             
                    if ('task' in decodeOptions):
         | 
| 237 | 
             
                        task = decodeOptions.pop('task')
         | 
| 238 |  | 
| 239 | 
             
                    # Callable for processing an audio file
         | 
| 240 | 
             
                    whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
         | 
| 241 |  | 
|  | |
|  | |
|  | |
| 242 | 
             
                    # The results
         | 
| 243 | 
             
                    if (vad == 'silero-vad'):
         | 
| 244 | 
             
                        # Silero VAD where non-speech gaps are transcribed
         | 
    	
        src/source.py
    CHANGED
    
    | @@ -12,15 +12,22 @@ from src.download import ExceededMaximumDuration, download_url | |
| 12 | 
             
            MAX_FILE_PREFIX_LENGTH = 17
         | 
| 13 |  | 
| 14 | 
             
            class AudioSource:
         | 
| 15 | 
            -
                def __init__(self, source_path, source_name = None):
         | 
| 16 | 
             
                    self.source_path = source_path
         | 
| 17 | 
             
                    self.source_name = source_name
         | 
|  | |
| 18 |  | 
| 19 | 
             
                    # Load source name if not provided
         | 
| 20 | 
             
                    if (self.source_name is None):
         | 
| 21 | 
             
                        file_path = pathlib.Path(self.source_path)
         | 
| 22 | 
             
                        self.source_name = file_path.name
         | 
| 23 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 24 | 
             
                def get_full_name(self):
         | 
| 25 | 
             
                    return self.source_name
         | 
| 26 |  | 
| @@ -53,18 +60,21 @@ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneDat | |
| 53 | 
             
                    if (microphoneData is not None):
         | 
| 54 | 
             
                        output.append(AudioSource(microphoneData))
         | 
| 55 |  | 
| 56 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 57 |  | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                     | 
| 61 | 
            -
                         | 
| 62 | 
            -
                        total_duration += float(audioDuration)
         | 
| 63 |  | 
| 64 | 
            -
                    # Ensure the total duration of the audio is not too long
         | 
| 65 | 
            -
                    if input_audio_max_duration > 0:
         | 
| 66 | 
            -
                        if float(total_duration) > input_audio_max_duration:
         | 
| 67 | 
            -
                            raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
         | 
| 68 | 
            -
                            
         | 
| 69 | 
             
                # Return a list of audio sources
         | 
| 70 | 
             
                return output
         | 
|  | |
| 12 | 
             
            MAX_FILE_PREFIX_LENGTH = 17
         | 
| 13 |  | 
| 14 | 
             
            class AudioSource:
         | 
| 15 | 
            +
                def __init__(self, source_path, source_name = None, audio_duration = None):
         | 
| 16 | 
             
                    self.source_path = source_path
         | 
| 17 | 
             
                    self.source_name = source_name
         | 
| 18 | 
            +
                    self._audio_duration = audio_duration
         | 
| 19 |  | 
| 20 | 
             
                    # Load source name if not provided
         | 
| 21 | 
             
                    if (self.source_name is None):
         | 
| 22 | 
             
                        file_path = pathlib.Path(self.source_path)
         | 
| 23 | 
             
                        self.source_name = file_path.name
         | 
| 24 |  | 
| 25 | 
            +
                def get_audio_duration(self):
         | 
| 26 | 
            +
                    if self._audio_duration is None:
         | 
| 27 | 
            +
                        self._audio_duration = float(ffmpeg.probe(self.source_path)["format"]["duration"])
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    return self._audio_duration
         | 
| 30 | 
            +
             | 
| 31 | 
             
                def get_full_name(self):
         | 
| 32 | 
             
                    return self.source_name
         | 
| 33 |  | 
|  | |
| 60 | 
             
                    if (microphoneData is not None):
         | 
| 61 | 
             
                        output.append(AudioSource(microphoneData))
         | 
| 62 |  | 
| 63 | 
            +
                total_duration = 0
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                # Calculate total audio length. We do this even if input_audio_max_duration
         | 
| 66 | 
            +
                # is disabled to ensure that all the audio files are valid.
         | 
| 67 | 
            +
                for source in output:
         | 
| 68 | 
            +
                    audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
         | 
| 69 | 
            +
                    total_duration += float(audioDuration)
         | 
| 70 | 
            +
                    
         | 
| 71 | 
            +
                    # Save audio duration
         | 
| 72 | 
            +
                    source._audio_duration = float(audioDuration)
         | 
| 73 |  | 
| 74 | 
            +
                # Ensure the total duration of the audio is not too long
         | 
| 75 | 
            +
                if input_audio_max_duration > 0:
         | 
| 76 | 
            +
                    if float(total_duration) > input_audio_max_duration:
         | 
| 77 | 
            +
                        raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
         | 
|  | |
| 78 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 79 | 
             
                # Return a list of audio sources
         | 
| 80 | 
             
                return output
         |