Spaces:
Running
Running
jhj0517
commited on
Commit
·
a3de454
1
Parent(s):
a85d7d2
Rename class & file
Browse files
modules/whisper/{whisper_base.py → base_transcription_pipeline.py}
RENAMED
|
@@ -24,7 +24,7 @@ from modules.diarize.diarizer import Diarizer
|
|
| 24 |
from modules.vad.silero_vad import SileroVAD
|
| 25 |
|
| 26 |
|
| 27 |
-
class
|
| 28 |
def __init__(self,
|
| 29 |
model_dir: str = WHISPER_MODELS_DIR,
|
| 30 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
@@ -464,7 +464,7 @@ class WhisperBase(ABC):
|
|
| 464 |
if torch.cuda.is_available():
|
| 465 |
return "cuda"
|
| 466 |
elif torch.backends.mps.is_available():
|
| 467 |
-
if not
|
| 468 |
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
| 469 |
return "cpu"
|
| 470 |
return "mps"
|
|
|
|
| 24 |
from modules.vad.silero_vad import SileroVAD
|
| 25 |
|
| 26 |
|
| 27 |
+
class BaseTranscriptionPipeline(ABC):
|
| 28 |
def __init__(self,
|
| 29 |
model_dir: str = WHISPER_MODELS_DIR,
|
| 30 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
|
| 464 |
if torch.cuda.is_available():
|
| 465 |
return "cuda"
|
| 466 |
elif torch.backends.mps.is_available():
|
| 467 |
+
if not BaseTranscriptionPipeline.is_sparse_api_supported():
|
| 468 |
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
| 469 |
return "cpu"
|
| 470 |
return "mps"
|
modules/whisper/faster_whisper_inference.py
CHANGED
|
@@ -13,10 +13,10 @@ from argparse import Namespace
|
|
| 13 |
|
| 14 |
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
| 15 |
from modules.whisper.data_classes import *
|
| 16 |
-
from modules.whisper.
|
| 17 |
|
| 18 |
|
| 19 |
-
class FasterWhisperInference(
|
| 20 |
def __init__(self,
|
| 21 |
model_dir: str = FASTER_WHISPER_MODELS_DIR,
|
| 22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
|
| 13 |
|
| 14 |
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
| 15 |
from modules.whisper.data_classes import *
|
| 16 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
| 17 |
|
| 18 |
|
| 19 |
+
class FasterWhisperInference(BaseTranscriptionPipeline):
|
| 20 |
def __init__(self,
|
| 21 |
model_dir: str = FASTER_WHISPER_MODELS_DIR,
|
| 22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
|
@@ -13,10 +13,10 @@ from argparse import Namespace
|
|
| 13 |
|
| 14 |
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
| 15 |
from modules.whisper.data_classes import *
|
| 16 |
-
from modules.whisper.
|
| 17 |
|
| 18 |
|
| 19 |
-
class InsanelyFastWhisperInference(
|
| 20 |
def __init__(self,
|
| 21 |
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
|
| 22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
|
| 13 |
|
| 14 |
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
| 15 |
from modules.whisper.data_classes import *
|
| 16 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
| 17 |
|
| 18 |
|
| 19 |
+
class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
|
| 20 |
def __init__(self,
|
| 21 |
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
|
| 22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
modules/whisper/whisper_Inference.py
CHANGED
|
@@ -8,11 +8,11 @@ import os
|
|
| 8 |
from argparse import Namespace
|
| 9 |
|
| 10 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
|
| 11 |
-
from modules.whisper.
|
| 12 |
from modules.whisper.data_classes import *
|
| 13 |
|
| 14 |
|
| 15 |
-
class WhisperInference(
|
| 16 |
def __init__(self,
|
| 17 |
model_dir: str = WHISPER_MODELS_DIR,
|
| 18 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
|
| 8 |
from argparse import Namespace
|
| 9 |
|
| 10 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
|
| 11 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
| 12 |
from modules.whisper.data_classes import *
|
| 13 |
|
| 14 |
|
| 15 |
+
class WhisperInference(BaseTranscriptionPipeline):
|
| 16 |
def __init__(self,
|
| 17 |
model_dir: str = WHISPER_MODELS_DIR,
|
| 18 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
modules/whisper/whisper_factory.py
CHANGED
|
@@ -6,7 +6,7 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
|
|
| 6 |
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
| 7 |
from modules.whisper.whisper_Inference import WhisperInference
|
| 8 |
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
| 9 |
-
from modules.whisper.
|
| 10 |
from modules.whisper.data_classes import *
|
| 11 |
|
| 12 |
|
|
@@ -20,7 +20,7 @@ class WhisperFactory:
|
|
| 20 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
| 21 |
uvr_model_dir: str = UVR_MODELS_DIR,
|
| 22 |
output_dir: str = OUTPUT_DIR,
|
| 23 |
-
) -> "
|
| 24 |
"""
|
| 25 |
Create a whisper inference class based on the provided whisper_type.
|
| 26 |
|
|
@@ -46,7 +46,7 @@ class WhisperFactory:
|
|
| 46 |
|
| 47 |
Returns
|
| 48 |
-------
|
| 49 |
-
|
| 50 |
An instance of the appropriate whisper inference class based on the whisper_type.
|
| 51 |
"""
|
| 52 |
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|
|
|
|
| 6 |
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
| 7 |
from modules.whisper.whisper_Inference import WhisperInference
|
| 8 |
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
| 9 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
| 10 |
from modules.whisper.data_classes import *
|
| 11 |
|
| 12 |
|
|
|
|
| 20 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
| 21 |
uvr_model_dir: str = UVR_MODELS_DIR,
|
| 22 |
output_dir: str = OUTPUT_DIR,
|
| 23 |
+
) -> "BaseTranscriptionPipeline":
|
| 24 |
"""
|
| 25 |
Create a whisper inference class based on the provided whisper_type.
|
| 26 |
|
|
|
|
| 46 |
|
| 47 |
Returns
|
| 48 |
-------
|
| 49 |
+
BaseTranscriptionPipeline
|
| 50 |
An instance of the appropriate whisper inference class based on the whisper_type.
|
| 51 |
"""
|
| 52 |
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|