Update whisper_cs.py

#34
by ssolito - opened
Files changed (1) hide show
  1. whisper_cs.py +6 -6
whisper_cs.py CHANGED
@@ -14,6 +14,7 @@ torch_dtype = torch.float32
14
  MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
15
  MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
17
 
18
  def clean_text(input_text):
19
  remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
@@ -104,21 +105,20 @@ def cleanup_temp_files(*file_paths):
104
  if path and os.path.exists(path):
105
  os.remove(path)
106
 
107
- if torch.cuda.is_available():
108
  faster_model = WhisperModel(
109
  MODEL_PATH_V2_FAST,
110
- device="cuda",
111
- compute_type="float16" #"int8_float16"
112
  )
113
- else:
 
114
  faster_model = WhisperModel(
115
  MODEL_PATH_V2_FAST,
116
  device="cpu",
117
  compute_type="int8"
118
  )
119
 
120
- faster_model = WhisperModel(MODEL_PATH_V2_FAST, device=DEVICE, compute_type="float16")
121
-
122
  def load_whisper_model(model_path: str):
123
  device = "cuda" if torch.cuda.is_available() else "cpu"
124
  model = whisper_ts.load_model(model_path, device=device)
 
14
  MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
15
  MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
+ print("[INFO] CUDA available:", torch.cuda.is_available())
18
 
19
  def clean_text(input_text):
20
  remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
 
105
  if path and os.path.exists(path):
106
  os.remove(path)
107
 
108
+ try:
109
  faster_model = WhisperModel(
110
  MODEL_PATH_V2_FAST,
111
+ device="cuda" if torch.cuda.is_available() else "cpu",
112
+ compute_type="float16" if torch.cuda.is_available() else "int8"
113
  )
114
+ except RuntimeError as e:
115
+ print(f"[WARNING] Failed to load model on GPU: {e}")
116
  faster_model = WhisperModel(
117
  MODEL_PATH_V2_FAST,
118
  device="cpu",
119
  compute_type="int8"
120
  )
121
 
 
 
122
  def load_whisper_model(model_path: str):
123
  device = "cuda" if torch.cuda.is_available() else "cpu"
124
  model = whisper_ts.load_model(model_path, device=device)