jhj0517 commited on
Commit
a3de454
·
1 Parent(s): a85d7d2

Rename class & file

Browse files
modules/whisper/{whisper_base.py → base_transcription_pipeline.py} RENAMED
@@ -24,7 +24,7 @@ from modules.diarize.diarizer import Diarizer
24
  from modules.vad.silero_vad import SileroVAD
25
 
26
 
27
- class WhisperBase(ABC):
28
  def __init__(self,
29
  model_dir: str = WHISPER_MODELS_DIR,
30
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -464,7 +464,7 @@ class WhisperBase(ABC):
464
  if torch.cuda.is_available():
465
  return "cuda"
466
  elif torch.backends.mps.is_available():
467
- if not WhisperBase.is_sparse_api_supported():
468
  # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
469
  return "cpu"
470
  return "mps"
 
24
  from modules.vad.silero_vad import SileroVAD
25
 
26
 
27
+ class BaseTranscriptionPipeline(ABC):
28
  def __init__(self,
29
  model_dir: str = WHISPER_MODELS_DIR,
30
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
464
  if torch.cuda.is_available():
465
  return "cuda"
466
  elif torch.backends.mps.is_available():
467
+ if not BaseTranscriptionPipeline.is_sparse_api_supported():
468
  # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
469
  return "cpu"
470
  return "mps"
modules/whisper/faster_whisper_inference.py CHANGED
@@ -13,10 +13,10 @@ from argparse import Namespace
13
 
14
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
  from modules.whisper.data_classes import *
16
- from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
- class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
  model_dir: str = FASTER_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
13
 
14
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
  from modules.whisper.data_classes import *
16
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
17
 
18
 
19
+ class FasterWhisperInference(BaseTranscriptionPipeline):
20
  def __init__(self,
21
  model_dir: str = FASTER_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -13,10 +13,10 @@ from argparse import Namespace
13
 
14
  from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
  from modules.whisper.data_classes import *
16
- from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
- class InsanelyFastWhisperInference(WhisperBase):
20
  def __init__(self,
21
  model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
13
 
14
  from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
  from modules.whisper.data_classes import *
16
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
17
 
18
 
19
+ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
20
  def __init__(self,
21
  model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
modules/whisper/whisper_Inference.py CHANGED
@@ -8,11 +8,11 @@ import os
8
  from argparse import Namespace
9
 
10
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
11
- from modules.whisper.whisper_base import WhisperBase
12
  from modules.whisper.data_classes import *
13
 
14
 
15
- class WhisperInference(WhisperBase):
16
  def __init__(self,
17
  model_dir: str = WHISPER_MODELS_DIR,
18
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
8
  from argparse import Namespace
9
 
10
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
11
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
12
  from modules.whisper.data_classes import *
13
 
14
 
15
+ class WhisperInference(BaseTranscriptionPipeline):
16
  def __init__(self,
17
  model_dir: str = WHISPER_MODELS_DIR,
18
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
modules/whisper/whisper_factory.py CHANGED
@@ -6,7 +6,7 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
6
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
7
  from modules.whisper.whisper_Inference import WhisperInference
8
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
9
- from modules.whisper.whisper_base import WhisperBase
10
  from modules.whisper.data_classes import *
11
 
12
 
@@ -20,7 +20,7 @@ class WhisperFactory:
20
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
21
  uvr_model_dir: str = UVR_MODELS_DIR,
22
  output_dir: str = OUTPUT_DIR,
23
- ) -> "WhisperBase":
24
  """
25
  Create a whisper inference class based on the provided whisper_type.
26
 
@@ -46,7 +46,7 @@ class WhisperFactory:
46
 
47
  Returns
48
  -------
49
- WhisperBase
50
  An instance of the appropriate whisper inference class based on the whisper_type.
51
  """
52
  # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
 
6
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
7
  from modules.whisper.whisper_Inference import WhisperInference
8
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
9
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
10
  from modules.whisper.data_classes import *
11
 
12
 
 
20
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
21
  uvr_model_dir: str = UVR_MODELS_DIR,
22
  output_dir: str = OUTPUT_DIR,
23
+ ) -> "BaseTranscriptionPipeline":
24
  """
25
  Create a whisper inference class based on the provided whisper_type.
26
 
 
46
 
47
  Returns
48
  -------
49
+ BaseTranscriptionPipeline
50
  An instance of the appropriate whisper inference class based on the whisper_type.
51
  """
52
  # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144