# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faster_whisper
from typing import List, Union, Optional, NamedTuple
import torch
import numpy as np
import tqdm

from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import TranscriptionResult, SingleSegment
from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens


class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
    """
    FasterWhisperModel without VAD
    """

    def __init__(
        self,
        model,
        options: NamedTuple,
        tokenizer=None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """
        Initialize the VadFreeFasterWhisperPipeline.

        Args:
            model: The Whisper model instance.
            options: Transcription options.
            tokenizer: The tokenizer instance.
            device: Device to run the model on.
            framework: The framework to use ('pt' for PyTorch).
            language: The language for transcription.
            suppress_numerals: Whether to suppress numeral tokens.
            **kwargs: Additional keyword arguments.

        Returns:
            None
        """
        super().__init__(
            model=model,
            vad=None,
            vad_params={},
            options=options,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
            language=language,
            suppress_numerals=suppress_numerals,
            **kwargs,
        )

    def detect_language(self, audio: np.ndarray):
        """
        Detect the language of the audio.

        Args:
            audio (np.ndarray): The input audio signal.

        Returns:
            tuple: Detected language and its probability.
        """
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        if audio.shape[0] > N_SAMPLES:
            # Randomly sample N_SAMPLES from the audio array
            start_index = np.random.randint(0, audio.shape[0] - N_SAMPLES)
            audio_sample = audio[start_index : start_index + N_SAMPLES]
        else:
            audio_sample = audio[:N_SAMPLES]
        padding = 0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]
        segment = log_mel_spectrogram(
            audio_sample,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=padding,
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        language = language_token[2:-2]
        return language, language_probability

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        vad_segments: List[dict],
        batch_size=None,
        num_workers=0,
        language=None,
        task=None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
    ) -> TranscriptionResult:
        """
        Transcribe the audio into text.

        Args:
            audio (Union[str, np.ndarray]): The input audio signal or path to audio file.
            vad_segments (List[dict]): List of VAD segments.
            batch_size (int, optional): Batch size for transcription. Defaults to None.
            num_workers (int, optional): Number of workers for loading data. Defaults to 0.
            language (str, optional): Language for transcription. Defaults to None.
            task (str, optional): Task type ('transcribe' or 'translate'). Defaults to None.
            chunk_size (int, optional): Size of chunks for processing. Defaults to 30.
            print_progress (bool, optional): Whether to print progress. Defaults to False.
            combined_progress (bool, optional): Whether to combine progress. Defaults to False.

        Returns:
            TranscriptionResult: The transcription result containing segments and language.
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {"inputs": audio[f1:f2]}

        if self.tokenizer is None:
            # detect_language returns (language, probability); keep only the language code
            language = language or self.detect_language(audio)[0]
            task = task or "transcribe"
            self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
            new_suppressed_tokens = list(set(new_suppressed_tokens))
            self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)

        segments: List[SingleSegment] = []
        batch_size = batch_size or self._batch_size
        total_segments = len(vad_segments)
        progress = tqdm.tqdm(total=total_segments, desc="Transcribing")
        for idx, out in enumerate(
            self.__call__(
                data(audio, vad_segments),
                batch_size=batch_size,
                num_workers=num_workers,
            )
        ):
            if print_progress:
                progress.update(1)
            text = out["text"]
            if batch_size in [0, 1, None]:
                text = text[0]
            segments.append(
                {
                    "text": text,
                    "start": round(vad_segments[idx]["start"], 3),
                    "end": round(vad_segments[idx]["end"], 3),
                    "speaker": vad_segments[idx].get("speaker", None),
                }
            )

        # revert the tokenizer if multilingual inference is enabled
        if self.preset_language is None:
            self.tokenizer = None

        # revert suppressed tokens if suppress_numerals is enabled
        if self.suppress_numerals:
            self.options = self.options._replace(
                suppress_tokens=previous_suppress_tokens
            )

        return {"segments": segments, "language": language}


def load_asr_model(
    whisper_arch: str,
    device: str,
    device_index: int = 0,
    compute_type: str = "float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model=None,
    vad_options=None,
    model: Optional[WhisperModel] = None,
    task: str = "transcribe",
    download_root: Optional[str] = None,
    threads: int = 4,
) -> VadFreeFasterWhisperPipeline:
    """
    Load a Whisper model for inference.

    Args:
        whisper_arch (str): The name of the Whisper model to load.
        device (str): The device to load the model on.
        device_index (int, optional): The device index. Defaults to 0.
        compute_type (str, optional): The compute type to use for the model. Defaults to "float16".
        asr_options (Optional[dict], optional): Options for ASR. Defaults to None.
        language (Optional[str], optional): The language of the model. Defaults to None.
        vad_model: The VAD model instance. Defaults to None.
        vad_options: Options for VAD. Defaults to None.
        model (Optional[WhisperModel], optional): The WhisperModel instance to use. Defaults to None.
        task (str, optional): The task type ('transcribe' or 'translate'). Defaults to "transcribe".
        download_root (Optional[str], optional): The root directory to download the model to. Defaults to None.
        threads (int, optional): The number of CPU threads to use per worker. Defaults to 4.

    Returns:
        VadFreeFasterWhisperPipeline: The loaded Whisper pipeline.

    Raises:
        ValueError: If the whisper architecture is not recognized.
    """
    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(
        whisper_arch,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        download_root=download_root,
        cpu_threads=threads,
    )
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )
    else:
        print(
            "No language specified, language will be detected for each audio file (increases inference time)."
        )
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }
    if asr_options is not None:
        default_asr_options.update(asr_options)

    suppress_numerals = default_asr_options["suppress_numerals"]
    del default_asr_options["suppress_numerals"]

    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(
        **default_asr_options
    )

    return VadFreeFasterWhisperPipeline(
        model=model,
        options=default_asr_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
    )
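

# Minimal usage sketch (not part of the original module). It assumes a CUDA
# device, a "large-v3" checkpoint, and a hypothetical 16 kHz mono file at
# "./example.wav"; adjust these to your environment. Segments are supplied
# manually here because this pipeline performs no VAD of its own.
if __name__ == "__main__":
    pipeline = load_asr_model(
        whisper_arch="large-v3",
        device="cuda",
        compute_type="float16",
        language=None,  # let the pipeline detect the language per file
    )
    audio = load_audio("./example.wav")  # hypothetical path
    duration = audio.shape[0] / SAMPLE_RATE
    # One segment spanning the whole file; real callers would pass VAD output.
    segments = [{"start": 0.0, "end": round(duration, 3)}]
    result = pipeline.transcribe(audio, vad_segments=segments, batch_size=1)
    print(result["language"])
    for seg in result["segments"]:
        print(f'[{seg["start"]:.2f}-{seg["end"]:.2f}] {seg["text"]}')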