jhj0517 committed
Commit 50e9f88 · 1 Parent(s): bc6b2e9

Get compute types with ctranslate2

modules/whisper/faster_whisper_inference.py CHANGED
@@ -35,8 +35,6 @@ class FasterWhisperInference(WhisperBase):
         self.model_paths = self.get_model_paths()
         self.device = self.get_device()
         self.available_models = self.model_paths.keys()
-        self.available_compute_types = ctranslate2.get_supported_compute_types(
-            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
 
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
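
For context, the per-backend discovery removed here is equivalent to the standalone sketch below; it is now provided once by WhisperBase (see whisper_base.py further down). The exact types returned depend on the installed ctranslate2 build and the hardware, so the printed set is illustrative only.

# Minimal sketch of the logic removed from FasterWhisperInference.__init__.
import ctranslate2

device = "cpu"  # stands in for self.get_device(); "cuda" when a GPU is available
available_compute_types = ctranslate2.get_supported_compute_types(device)
print(available_compute_types)  # e.g. {'float32', 'int8', 'int8_float32'} on a CPU-only build
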
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -35,7 +35,6 @@ class InsanelyFastWhisperInference(WhisperBase):
         openai_models = whisper.available_models()
         distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
         self.available_models = openai_models + distil_models
-        self.available_compute_types = ["float16"]
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
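
With the hardcoded ["float16"] gone, this subclass no longer shadows the attribute set in WhisperBase.__init__, so the base class's ctranslate2-derived list is what the UI sees. A hypothetical minimal mirror of that inheritance behavior (class names and values here are illustrative, not from the repo):

# Hypothetical mirror of the change: once the subclass stops reassigning
# the attribute, the value set by the base __init__ is the one that survives.
class Base:
    def __init__(self):
        self.available_compute_types = ["float32", "int8"]  # ctranslate2-derived in the real code

class Sub(Base):
    def __init__(self):
        super().__init__()
        # previously: self.available_compute_types = ["float16"]  (now removed)

print(Sub().available_compute_types)  # ['float32', 'int8']
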
modules/whisper/whisper_base.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import torch
 import whisper
+import ctranslate2
 import gradio as gr
 import torchaudio
 from abc import ABC, abstractmethod
@@ -47,8 +48,8 @@ class WhisperBase(ABC):
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
         self.device = self.get_device()
-        self.available_compute_types = ["float16", "float32"]
-        self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+        self.available_compute_types = self.get_available_compute_type()
+        self.current_compute_type = self.get_compute_type()
 
     @abstractmethod
     def transcribe(self,
@@ -371,6 +372,18 @@ class WhisperBase(ABC):
         finally:
             self.release_cuda_memory()
 
+    def get_compute_type(self):
+        if "float16" in self.available_compute_types:
+            return "float16"
+        else:
+            return self.available_compute_types[0]
+
+    def get_available_compute_type(self):
+        if self.device == "cuda":
+            return ctranslate2.get_supported_compute_types("cuda")
+        else:
+            return ctranslate2.get_supported_compute_types("cpu")
+
     @staticmethod
     def generate_and_write_file(file_name: str,
                                 transcribed_segments: list,
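
Taken together, the two new helpers prefer "float16" whenever the device advertises it and otherwise fall back to the first supported type. A standalone sketch of that selection logic, runnable without the rest of WhisperBase; note that ctranslate2's get_supported_compute_types returns a set, so this sketch sorts it into a list before indexing.

import ctranslate2

def pick_compute_type(device: str) -> str:
    # Mirrors get_available_compute_type() and get_compute_type() above.
    supported = ctranslate2.get_supported_compute_types(device)
    if "float16" in supported:
        return "float16"
    # The supported types come back as a set, so order them before indexing.
    return sorted(supported)[0]

print(pick_compute_type("cpu"))  # e.g. 'float32', depending on the build
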