Update app.py
app.py CHANGED

@@ -18,13 +18,8 @@ def write_file(output_file, subtitle):
         f.write(subtitle)
 
 def create_pipe(model, flash):
-    if torch.cuda.is_available():
-        device = "cuda:0"
-    elif platform == "darwin":
-        device = "mps"
-    else:
-        device = "cpu"
-    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    # Load the model into RAM first
+    torch_dtype = torch.float32 # Load onto CPU with float32 precision
     model_id = model
 
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -33,13 +28,8 @@ def create_pipe(model, flash):
         low_cpu_mem_usage=True,
         use_safetensors=True,
         attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
-        # eager (manual attention implementation)
-        # flash_attention_2 (implementation using flash attention 2)
-        # sdpa (implementation using torch.nn.functional.scaled_dot_product_attention)
-        # PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.
     )
-
-
+
     processor = AutoProcessor.from_pretrained(model_id)
 
     pipe = pipeline(
@@ -47,15 +37,25 @@ def create_pipe(model, flash):
         model=model,
         tokenizer=processor.tokenizer,
         feature_extractor=processor.feature_extractor,
-        #
-        #
-        # batch_size=16,
-        torch_dtype=torch_dtype,
-        device=device,
+        torch_dtype=torch_dtype, # Keep in CPU until GPU is requested
+        device="cpu", # Initially stay on CPU
     )
-    return pipe
+    return pipe, model # Return both pipe and model for later GPU switch
 
 @spaces.GPU
+def move_to_gpu(model):
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        torch_dtype = torch.float16 # Use float16 precision on GPU
+        model.to(device, dtype=torch_dtype)
+    elif platform == "darwin":
+        device = "mps"
+        model.to(device)
+    else:
+        device = "cpu"
+
+    return device
+
 def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                      chunk_length_s, batch_size, progress=gr.Progress()):
     global last_model
@@ -73,16 +73,24 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF
     if last_model is None:
         logging.info("first model")
         progress(0.1, desc="Loading Model..")
-        pipe = create_pipe(modelName, flash)
+        pipe, model = create_pipe(modelName, flash)
     elif modelName != last_model:
         logging.info("new model")
         torch.cuda.empty_cache()
         progress(0.1, desc="Loading Model..")
-        pipe = create_pipe(modelName, flash)
+        pipe, model = create_pipe(modelName, flash)
     else:
         logging.info("Model not changed")
+
     last_model = modelName
 
+    # Now move the model to GPU after the pipe is created
+    device = move_to_gpu(pipe.model)
+
+    # Update pipe's device
+    pipe.device = torch.device(device)
+    pipe.model.to(pipe.device)
+
     srt_sub = Subtitle("srt")
     vtt_sub = Subtitle("vtt")
     txt_sub = Subtitle("txt")
@@ -176,4 +184,4 @@ with gr.Blocks(title="Insanely Fast Whisper") as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
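
Taken together, the change makes `create_pipe` device-agnostic: the checkpoint is always loaded on the CPU in float32, and the new `move_to_gpu` helper, sitting under the existing `@spaces.GPU` decorator, promotes the weights to CUDA (in float16), to MPS on macOS, or leaves them on the CPU. Below is a minimal, self-contained sketch of that load-then-promote pattern. It mirrors the names in the diff but trims the flash-attention switch and drops the `spaces` decorator so it runs outside a ZeroGPU Space; the `openai/whisper-tiny` checkpoint and `sample.wav` file at the bottom are illustrative assumptions, not part of the commit.

```python
import torch
from sys import platform
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


def create_pipe(model_id):
    # Load the checkpoint on the CPU in float32; startup never touches a GPU.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch.float32,
        device="cpu",  # stay on the CPU until a GPU is actually granted
    )
    return pipe, model


def move_to_gpu(model):
    # In the Space this function carries @spaces.GPU, so ZeroGPU attaches a
    # GPU only for the duration of the call; the decorator is omitted here.
    if torch.cuda.is_available():
        device = "cuda:0"
        model.to(device, dtype=torch.float16)  # use float16 on CUDA
    elif platform == "darwin":
        device = "mps"
        model.to(device)
    else:
        device = "cpu"
    return device


pipe, model = create_pipe("openai/whisper-tiny")  # illustrative checkpoint
device = move_to_gpu(pipe.model)
pipe.device = torch.device(device)  # keep the pipeline's device in sync
pipe.model.to(pipe.device)
print(pipe("sample.wav")["text"])   # assumes a local audio file exists
```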
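Two details of the new flow are easy to miss. First, because `def move_to_gpu(model):` is inserted directly beneath the existing `@spaces.GPU` line, the decorator now applies to the helper rather than to `transcribe_webui_simple_progress`, so the GPU allocation is scoped to the move itself. Second, moving `pipe.model` on its own is not enough: a Transformers pipeline keeps its own `device` attribute and uses it to place input tensors, which is why the commit also sets `pipe.device = torch.device(device)` and re-applies `pipe.model.to(pipe.device)` after `move_to_gpu` returns.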