Spaces:
Running
Running
Merge pull request #290 from jhj0517/fix/defaults
Browse files
modules/translation/nllb_inference.py
CHANGED
|
@@ -35,7 +35,7 @@ class NLLBInference(TranslationBase):
|
|
| 35 |
model_size: str,
|
| 36 |
src_lang: str,
|
| 37 |
tgt_lang: str,
|
| 38 |
-
progress: gr.Progress
|
| 39 |
):
|
| 40 |
if model_size != self.current_model_size or self.model is None:
|
| 41 |
print("\nInitializing NLLB Model..\n")
|
|
|
|
| 35 |
model_size: str,
|
| 36 |
src_lang: str,
|
| 37 |
tgt_lang: str,
|
| 38 |
+
progress: gr.Progress = gr.Progress()
|
| 39 |
):
|
| 40 |
if model_size != self.current_model_size or self.model is None:
|
| 41 |
print("\nInitializing NLLB Model..\n")
|
modules/translation/translation_base.py
CHANGED
|
@@ -37,7 +37,7 @@ class TranslationBase(ABC):
|
|
| 37 |
model_size: str,
|
| 38 |
src_lang: str,
|
| 39 |
tgt_lang: str,
|
| 40 |
-
progress: gr.Progress
|
| 41 |
):
|
| 42 |
pass
|
| 43 |
|
|
|
|
| 37 |
model_size: str,
|
| 38 |
src_lang: str,
|
| 39 |
tgt_lang: str,
|
| 40 |
+
progress: gr.Progress = gr.Progress()
|
| 41 |
):
|
| 42 |
pass
|
| 43 |
|
modules/whisper/faster_whisper_inference.py
CHANGED
|
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
|
|
| 40 |
|
| 41 |
def transcribe(self,
|
| 42 |
audio: Union[str, BinaryIO, np.ndarray],
|
| 43 |
-
progress: gr.Progress,
|
| 44 |
*whisper_params,
|
| 45 |
) -> Tuple[List[dict], float]:
|
| 46 |
"""
|
|
@@ -126,7 +126,7 @@ class FasterWhisperInference(WhisperBase):
|
|
| 126 |
def update_model(self,
|
| 127 |
model_size: str,
|
| 128 |
compute_type: str,
|
| 129 |
-
progress: gr.Progress
|
| 130 |
):
|
| 131 |
"""
|
| 132 |
Update current model setting
|
|
|
|
| 40 |
|
| 41 |
def transcribe(self,
|
| 42 |
audio: Union[str, BinaryIO, np.ndarray],
|
| 43 |
+
progress: gr.Progress = gr.Progress(),
|
| 44 |
*whisper_params,
|
| 45 |
) -> Tuple[List[dict], float]:
|
| 46 |
"""
|
|
|
|
| 126 |
def update_model(self,
|
| 127 |
model_size: str,
|
| 128 |
compute_type: str,
|
| 129 |
+
progress: gr.Progress = gr.Progress()
|
| 130 |
):
|
| 131 |
"""
|
| 132 |
Update current model setting
|
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
|
@@ -39,7 +39,7 @@ class InsanelyFastWhisperInference(WhisperBase):
|
|
| 39 |
|
| 40 |
def transcribe(self,
|
| 41 |
audio: Union[str, np.ndarray, torch.Tensor],
|
| 42 |
-
progress: gr.Progress,
|
| 43 |
*whisper_params,
|
| 44 |
) -> Tuple[List[dict], float]:
|
| 45 |
"""
|
|
@@ -98,7 +98,7 @@ class InsanelyFastWhisperInference(WhisperBase):
|
|
| 98 |
def update_model(self,
|
| 99 |
model_size: str,
|
| 100 |
compute_type: str,
|
| 101 |
-
progress: gr.Progress,
|
| 102 |
):
|
| 103 |
"""
|
| 104 |
Update current model setting
|
|
|
|
| 39 |
|
| 40 |
def transcribe(self,
|
| 41 |
audio: Union[str, np.ndarray, torch.Tensor],
|
| 42 |
+
progress: gr.Progress = gr.Progress(),
|
| 43 |
*whisper_params,
|
| 44 |
) -> Tuple[List[dict], float]:
|
| 45 |
"""
|
|
|
|
| 98 |
def update_model(self,
|
| 99 |
model_size: str,
|
| 100 |
compute_type: str,
|
| 101 |
+
progress: gr.Progress = gr.Progress(),
|
| 102 |
):
|
| 103 |
"""
|
| 104 |
Update current model setting
|
modules/whisper/whisper_Inference.py
CHANGED
|
@@ -28,7 +28,7 @@ class WhisperInference(WhisperBase):
|
|
| 28 |
|
| 29 |
def transcribe(self,
|
| 30 |
audio: Union[str, np.ndarray, torch.Tensor],
|
| 31 |
-
progress: gr.Progress,
|
| 32 |
*whisper_params,
|
| 33 |
) -> Tuple[List[dict], float]:
|
| 34 |
"""
|
|
@@ -79,7 +79,7 @@ class WhisperInference(WhisperBase):
|
|
| 79 |
def update_model(self,
|
| 80 |
model_size: str,
|
| 81 |
compute_type: str,
|
| 82 |
-
progress: gr.Progress,
|
| 83 |
):
|
| 84 |
"""
|
| 85 |
Update current model setting
|
|
|
|
| 28 |
|
| 29 |
def transcribe(self,
|
| 30 |
audio: Union[str, np.ndarray, torch.Tensor],
|
| 31 |
+
progress: gr.Progress = gr.Progress(),
|
| 32 |
*whisper_params,
|
| 33 |
) -> Tuple[List[dict], float]:
|
| 34 |
"""
|
|
|
|
| 79 |
def update_model(self,
|
| 80 |
model_size: str,
|
| 81 |
compute_type: str,
|
| 82 |
+
progress: gr.Progress = gr.Progress(),
|
| 83 |
):
|
| 84 |
"""
|
| 85 |
Update current model setting
|
modules/whisper/whisper_base.py
CHANGED
|
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
|
|
| 53 |
@abstractmethod
|
| 54 |
def transcribe(self,
|
| 55 |
audio: Union[str, BinaryIO, np.ndarray],
|
| 56 |
-
progress: gr.Progress,
|
| 57 |
*whisper_params,
|
| 58 |
):
|
| 59 |
"""Inference whisper model to transcribe"""
|
|
@@ -63,7 +63,7 @@ class WhisperBase(ABC):
|
|
| 63 |
def update_model(self,
|
| 64 |
model_size: str,
|
| 65 |
compute_type: str,
|
| 66 |
-
progress: gr.Progress
|
| 67 |
):
|
| 68 |
"""Initialize whisper model"""
|
| 69 |
pass
|
|
@@ -171,10 +171,10 @@ class WhisperBase(ABC):
|
|
| 171 |
return result, elapsed_time
|
| 172 |
|
| 173 |
def transcribe_file(self,
|
| 174 |
-
files: list,
|
| 175 |
-
input_folder_path: str,
|
| 176 |
-
file_format: str,
|
| 177 |
-
add_timestamp: bool,
|
| 178 |
progress=gr.Progress(),
|
| 179 |
*whisper_params,
|
| 180 |
) -> list:
|
|
@@ -250,8 +250,8 @@ class WhisperBase(ABC):
|
|
| 250 |
|
| 251 |
def transcribe_mic(self,
|
| 252 |
mic_audio: str,
|
| 253 |
-
file_format: str,
|
| 254 |
-
add_timestamp: bool,
|
| 255 |
progress=gr.Progress(),
|
| 256 |
*whisper_params,
|
| 257 |
) -> list:
|
|
@@ -306,8 +306,8 @@ class WhisperBase(ABC):
|
|
| 306 |
|
| 307 |
def transcribe_youtube(self,
|
| 308 |
youtube_link: str,
|
| 309 |
-
file_format: str,
|
| 310 |
-
add_timestamp: bool,
|
| 311 |
progress=gr.Progress(),
|
| 312 |
*whisper_params,
|
| 313 |
) -> list:
|
|
@@ -411,11 +411,12 @@ class WhisperBase(ABC):
|
|
| 411 |
else:
|
| 412 |
output_path = os.path.join(output_dir, f"{file_name}")
|
| 413 |
|
| 414 |
-
if file_format == "SRT":
|
|
|
|
| 415 |
content = get_srt(transcribed_segments)
|
| 416 |
output_path += '.srt'
|
| 417 |
|
| 418 |
-
elif file_format == "WebVTT":
|
| 419 |
content = get_vtt(transcribed_segments)
|
| 420 |
output_path += '.vtt'
|
| 421 |
|
|
|
|
| 53 |
@abstractmethod
|
| 54 |
def transcribe(self,
|
| 55 |
audio: Union[str, BinaryIO, np.ndarray],
|
| 56 |
+
progress: gr.Progress = gr.Progress(),
|
| 57 |
*whisper_params,
|
| 58 |
):
|
| 59 |
"""Inference whisper model to transcribe"""
|
|
|
|
| 63 |
def update_model(self,
|
| 64 |
model_size: str,
|
| 65 |
compute_type: str,
|
| 66 |
+
progress: gr.Progress = gr.Progress()
|
| 67 |
):
|
| 68 |
"""Initialize whisper model"""
|
| 69 |
pass
|
|
|
|
| 171 |
return result, elapsed_time
|
| 172 |
|
| 173 |
def transcribe_file(self,
|
| 174 |
+
files: Optional[List] = None,
|
| 175 |
+
input_folder_path: Optional[str] = None,
|
| 176 |
+
file_format: str = "SRT",
|
| 177 |
+
add_timestamp: bool = True,
|
| 178 |
progress=gr.Progress(),
|
| 179 |
*whisper_params,
|
| 180 |
) -> list:
|
|
|
|
| 250 |
|
| 251 |
def transcribe_mic(self,
|
| 252 |
mic_audio: str,
|
| 253 |
+
file_format: str = "SRT",
|
| 254 |
+
add_timestamp: bool = True,
|
| 255 |
progress=gr.Progress(),
|
| 256 |
*whisper_params,
|
| 257 |
) -> list:
|
|
|
|
| 306 |
|
| 307 |
def transcribe_youtube(self,
|
| 308 |
youtube_link: str,
|
| 309 |
+
file_format: str = "SRT",
|
| 310 |
+
add_timestamp: bool = True,
|
| 311 |
progress=gr.Progress(),
|
| 312 |
*whisper_params,
|
| 313 |
) -> list:
|
|
|
|
| 411 |
else:
|
| 412 |
output_path = os.path.join(output_dir, f"{file_name}")
|
| 413 |
|
| 414 |
+
file_format = file_format.strip().lower()
|
| 415 |
+
if file_format == "srt":
|
| 416 |
content = get_srt(transcribed_segments)
|
| 417 |
output_path += '.srt'
|
| 418 |
|
| 419 |
+
elif file_format == "webvtt":
|
| 420 |
content = get_vtt(transcribed_segments)
|
| 421 |
output_path += '.vtt'
|
| 422 |
|