Spaces:
Running
Running
jhj0517
commited on
Commit
·
a3de454
1
Parent(s):
a85d7d2
Rename class & file
Browse files
modules/whisper/{whisper_base.py → base_transcription_pipeline.py}
RENAMED
@@ -24,7 +24,7 @@ from modules.diarize.diarizer import Diarizer
|
|
24 |
from modules.vad.silero_vad import SileroVAD
|
25 |
|
26 |
|
27 |
-
class
|
28 |
def __init__(self,
|
29 |
model_dir: str = WHISPER_MODELS_DIR,
|
30 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
@@ -464,7 +464,7 @@ class WhisperBase(ABC):
|
|
464 |
if torch.cuda.is_available():
|
465 |
return "cuda"
|
466 |
elif torch.backends.mps.is_available():
|
467 |
-
if not
|
468 |
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
469 |
return "cpu"
|
470 |
return "mps"
|
|
|
24 |
from modules.vad.silero_vad import SileroVAD
|
25 |
|
26 |
|
27 |
+
class BaseTranscriptionPipeline(ABC):
|
28 |
def __init__(self,
|
29 |
model_dir: str = WHISPER_MODELS_DIR,
|
30 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
464 |
if torch.cuda.is_available():
|
465 |
return "cuda"
|
466 |
elif torch.backends.mps.is_available():
|
467 |
+
if not BaseTranscriptionPipeline.is_sparse_api_supported():
|
468 |
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
469 |
return "cpu"
|
470 |
return "mps"
|
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -13,10 +13,10 @@ from argparse import Namespace
|
|
13 |
|
14 |
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
15 |
from modules.whisper.data_classes import *
|
16 |
-
from modules.whisper.
|
17 |
|
18 |
|
19 |
-
class FasterWhisperInference(
|
20 |
def __init__(self,
|
21 |
model_dir: str = FASTER_WHISPER_MODELS_DIR,
|
22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
13 |
|
14 |
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
15 |
from modules.whisper.data_classes import *
|
16 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
17 |
|
18 |
|
19 |
+
class FasterWhisperInference(BaseTranscriptionPipeline):
|
20 |
def __init__(self,
|
21 |
model_dir: str = FASTER_WHISPER_MODELS_DIR,
|
22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -13,10 +13,10 @@ from argparse import Namespace
|
|
13 |
|
14 |
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
15 |
from modules.whisper.data_classes import *
|
16 |
-
from modules.whisper.
|
17 |
|
18 |
|
19 |
-
class InsanelyFastWhisperInference(
|
20 |
def __init__(self,
|
21 |
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
|
22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
13 |
|
14 |
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
15 |
from modules.whisper.data_classes import *
|
16 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
17 |
|
18 |
|
19 |
+
class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
|
20 |
def __init__(self,
|
21 |
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
|
22 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
modules/whisper/whisper_Inference.py
CHANGED
@@ -8,11 +8,11 @@ import os
|
|
8 |
from argparse import Namespace
|
9 |
|
10 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
|
11 |
-
from modules.whisper.
|
12 |
from modules.whisper.data_classes import *
|
13 |
|
14 |
|
15 |
-
class WhisperInference(
|
16 |
def __init__(self,
|
17 |
model_dir: str = WHISPER_MODELS_DIR,
|
18 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
|
|
8 |
from argparse import Namespace
|
9 |
|
10 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
|
11 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
12 |
from modules.whisper.data_classes import *
|
13 |
|
14 |
|
15 |
+
class WhisperInference(BaseTranscriptionPipeline):
|
16 |
def __init__(self,
|
17 |
model_dir: str = WHISPER_MODELS_DIR,
|
18 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
modules/whisper/whisper_factory.py
CHANGED
@@ -6,7 +6,7 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
|
|
6 |
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
7 |
from modules.whisper.whisper_Inference import WhisperInference
|
8 |
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
9 |
-
from modules.whisper.
|
10 |
from modules.whisper.data_classes import *
|
11 |
|
12 |
|
@@ -20,7 +20,7 @@ class WhisperFactory:
|
|
20 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
21 |
uvr_model_dir: str = UVR_MODELS_DIR,
|
22 |
output_dir: str = OUTPUT_DIR,
|
23 |
-
) -> "
|
24 |
"""
|
25 |
Create a whisper inference class based on the provided whisper_type.
|
26 |
|
@@ -46,7 +46,7 @@ class WhisperFactory:
|
|
46 |
|
47 |
Returns
|
48 |
-------
|
49 |
-
|
50 |
An instance of the appropriate whisper inference class based on the whisper_type.
|
51 |
"""
|
52 |
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|
|
|
6 |
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
7 |
from modules.whisper.whisper_Inference import WhisperInference
|
8 |
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
9 |
+
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
10 |
from modules.whisper.data_classes import *
|
11 |
|
12 |
|
|
|
20 |
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
21 |
uvr_model_dir: str = UVR_MODELS_DIR,
|
22 |
output_dir: str = OUTPUT_DIR,
|
23 |
+
) -> "BaseTranscriptionPipeline":
|
24 |
"""
|
25 |
Create a whisper inference class based on the provided whisper_type.
|
26 |
|
|
|
46 |
|
47 |
Returns
|
48 |
-------
|
49 |
+
BaseTranscriptionPipeline
|
50 |
An instance of the appropriate whisper inference class based on the whisper_type.
|
51 |
"""
|
52 |
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|