reedmayhew committed on
Commit 5d0e917 · verified · 1 Parent(s): d4a3a9d

Update app.py

Files changed (1): app.py (+30 −22)
app.py CHANGED
@@ -18,13 +18,8 @@ def write_file(output_file, subtitle):
         f.write(subtitle)
 
 def create_pipe(model, flash):
-    if torch.cuda.is_available():
-        device = "cuda:0"
-    elif platform == "darwin":
-        device = "mps"
-    else:
-        device = "cpu"
-    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    # Load the model into RAM first
+    torch_dtype = torch.float32  # Load onto CPU with float32 precision
     model_id = model
 
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -33,13 +28,8 @@ def create_pipe(model, flash):
         low_cpu_mem_usage=True,
         use_safetensors=True,
         attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
-        # eager (manual attention implementation)
-        # flash_attention_2 (implementation using flash attention 2)
-        # sdpa (implementation using torch.nn.functional.scaled_dot_product_attention)
-        # PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.
     )
-    model.to(device)
-
+
     processor = AutoProcessor.from_pretrained(model_id)
 
     pipe = pipeline(
@@ -47,15 +37,25 @@ def create_pipe(model, flash):
         model=model,
         tokenizer=processor.tokenizer,
         feature_extractor=processor.feature_extractor,
-        # max_new_tokens=128,
-        # chunk_length_s=15,
-        # batch_size=16,
-        torch_dtype=torch_dtype,
-        device=device,
+        torch_dtype=torch_dtype,  # Keep in CPU until GPU is requested
+        device="cpu",  # Initially stay on CPU
     )
-    return pipe
+    return pipe, model  # Return both pipe and model for later GPU switch
 
 @spaces.GPU
+def move_to_gpu(model):
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        torch_dtype = torch.float16  # Use float16 precision on GPU
+        model.to(device, dtype=torch_dtype)
+    elif platform == "darwin":
+        device = "mps"
+        model.to(device)
+    else:
+        device = "cpu"
+
+    return device
+
 def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                      chunk_length_s, batch_size, progress=gr.Progress()):
     global last_model
@@ -73,16 +73,24 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF
     if last_model is None:
         logging.info("first model")
         progress(0.1, desc="Loading Model..")
-        pipe = create_pipe(modelName, flash)
+        pipe, model = create_pipe(modelName, flash)
     elif modelName != last_model:
         logging.info("new model")
         torch.cuda.empty_cache()
         progress(0.1, desc="Loading Model..")
-        pipe = create_pipe(modelName, flash)
+        pipe, model = create_pipe(modelName, flash)
     else:
         logging.info("Model not changed")
+
     last_model = modelName
 
+    # Now move the model to GPU after the pipe is created
+    device = move_to_gpu(pipe.model)
+
+    # Update pipe's device
+    pipe.device = torch.device(device)
+    pipe.model.to(pipe.device)
+
     srt_sub = Subtitle("srt")
     vtt_sub = Subtitle("vtt")
     txt_sub = Subtitle("txt")
@@ -176,4 +184,4 @@ with gr.Blocks(title="Insanely Fast Whisper") as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
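
The net effect of the diff: create_pipe no longer does any device selection; it always builds the pipeline on the CPU in float32, and the CUDA move (with a float16 downcast) happens later in move_to_gpu, which now carries the @spaces.GPU decorator, so under ZeroGPU the weights only touch the GPU once a device has actually been granted. A minimal self-contained sketch of the same load-on-CPU, promote-on-demand pattern follows; the model id and function names are illustrative, not part of the commit:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

def build_cpu_pipe(model_id="openai/whisper-large-v3"):  # example model id
    # Load weights on CPU in float32 so no GPU is touched at load time.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device="cpu",
    )

def promote_to_gpu(pipe):
    # Mirrors move_to_gpu() above: cast to float16 and move only if CUDA exists.
    if torch.cuda.is_available():
        pipe.model.to("cuda:0", dtype=torch.float16)
        pipe.device = torch.device("cuda:0")
    return pipe

Keeping the heavyweight from_pretrained call outside the GPU-decorated function appears to be the point of the change: on a ZeroGPU Space, GPU time is only held while a @spaces.GPU function runs, so only the comparatively cheap .to() transfer happens on the clock.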