jhj0517 commited on
Commit
ad742b8
·
1 Parent(s): 80493bc

commend get_device() on abstract function

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +15 -15
modules/whisper/whisper_base.py CHANGED
@@ -42,7 +42,7 @@ class WhisperBase(ABC):
42
  self.vad = SileroVAD()
43
 
44
  @abstractmethod
45
- #@spaces.GPU(duration=120)
46
  def transcribe(self,
47
  audio: Union[str, BinaryIO, np.ndarray],
48
  progress: gr.Progress,
@@ -51,7 +51,7 @@ class WhisperBase(ABC):
51
  pass
52
 
53
  @abstractmethod
54
- #@spaces.GPU(duration=120)
55
  def update_model(self,
56
  model_size: str,
57
  compute_type: str,
@@ -59,7 +59,7 @@ class WhisperBase(ABC):
59
  ):
60
  pass
61
 
62
- #@spaces.GPU(duration=120)
63
  def run(self,
64
  audio: Union[str, BinaryIO, np.ndarray],
65
  progress: gr.Progress,
@@ -196,7 +196,7 @@ class WhisperBase(ABC):
196
  if not files:
197
  self.remove_input_files([file.name for file in files])
198
 
199
- #@spaces.GPU(duration=120)
200
  def transcribe_mic(self,
201
  mic_audio: str,
202
  file_format: str,
@@ -249,7 +249,7 @@ class WhisperBase(ABC):
249
  self.release_cuda_memory()
250
  self.remove_input_files([mic_audio])
251
 
252
- #@spaces.GPU(duration=120)
253
  def transcribe_youtube(self,
254
  youtube_link: str,
255
  file_format: str,
@@ -399,18 +399,18 @@ class WhisperBase(ABC):
399
 
400
  return time_str.strip()
401
 
402
- @staticmethod
403
- #@spaces.GPU(duration=120)
404
- def get_device():
405
- if torch.cuda.is_available():
406
- return "cuda"
407
- elif torch.backends.mps.is_available():
408
- return "mps"
409
- else:
410
- return "cpu"
411
 
412
  @staticmethod
413
- #@spaces.GPU(duration=120)
414
  def release_cuda_memory():
415
  if torch.cuda.is_available():
416
  torch.cuda.empty_cache()
 
42
  self.vad = SileroVAD()
43
 
44
  @abstractmethod
45
+ @spaces.GPU(duration=120)
46
  def transcribe(self,
47
  audio: Union[str, BinaryIO, np.ndarray],
48
  progress: gr.Progress,
 
51
  pass
52
 
53
  @abstractmethod
54
+ @spaces.GPU(duration=120)
55
  def update_model(self,
56
  model_size: str,
57
  compute_type: str,
 
59
  ):
60
  pass
61
 
62
+ @spaces.GPU(duration=120)
63
  def run(self,
64
  audio: Union[str, BinaryIO, np.ndarray],
65
  progress: gr.Progress,
 
196
  if not files:
197
  self.remove_input_files([file.name for file in files])
198
 
199
+ @spaces.GPU(duration=120)
200
  def transcribe_mic(self,
201
  mic_audio: str,
202
  file_format: str,
 
249
  self.release_cuda_memory()
250
  self.remove_input_files([mic_audio])
251
 
252
+ @spaces.GPU(duration=120)
253
  def transcribe_youtube(self,
254
  youtube_link: str,
255
  file_format: str,
 
399
 
400
  return time_str.strip()
401
 
402
+ # @staticmethod
403
+ # @spaces.GPU(duration=120)
404
+ # def get_device():
405
+ # if torch.cuda.is_available():
406
+ # return "cuda"
407
+ # elif torch.backends.mps.is_available():
408
+ # return "mps"
409
+ # else:
410
+ # return "cpu"
411
 
412
  @staticmethod
413
+ @spaces.GPU(duration=120)
414
  def release_cuda_memory():
415
  if torch.cuda.is_available():
416
  torch.cuda.empty_cache()