reedmayhew committed on
Commit
97bdbba
·
verified ·
1 Parent(s): 5a0e518

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -52
app.py CHANGED
@@ -42,13 +42,20 @@ def create_pipe(model, flash):
42
  )
43
  return pipe, model # Return both pipe and model for later GPU switch
44
 
45
- @spaces.GPU(duration=120)
46
  def move_to_gpu(model):
47
- device = "cuda:0"
48
- torch_dtype = torch.float16 # Use float16 precision on GPU
49
- model.to(device, dtype=torch_dtype)
 
 
 
 
 
 
 
50
  return device
51
 
 
52
  def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
53
  chunk_length_s, batch_size, progress=gr.Progress()):
54
  global last_model
@@ -77,54 +84,55 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF
77
 
78
  last_model = modelName
79
 
80
- # Now move the model to GPU after the pipe is created
81
- device = move_to_gpu(pipe.model)
82
-
83
- # Update pipe's device
84
- pipe.device = torch.device(device)
85
- pipe.model.to(pipe.device)
86
-
87
- srt_sub = Subtitle("srt")
88
- vtt_sub = Subtitle("vtt")
89
- txt_sub = Subtitle("txt")
90
-
91
- files = []
92
- if multipleFiles:
93
- files += multipleFiles
94
- if urlData:
95
- files.append(urlData)
96
- if microphoneData:
97
- files.append(microphoneData)
98
- logging.info(files)
99
-
100
- generate_kwargs = {}
101
- if languageName != "Automatic Detection" and modelName.endswith(".en") == False:
102
- generate_kwargs["language"] = languageName
103
- if modelName.endswith(".en") == False:
104
- generate_kwargs["task"] = task
105
-
106
- files_out = []
107
- for file in progress.tqdm(files, desc="Working..."):
108
- start_time = time.time()
109
- logging.info(file)
110
- outputs = pipe(
111
- file,
112
- chunk_length_s=chunk_length_s, # 30
113
- batch_size=batch_size, # 24
114
- generate_kwargs=generate_kwargs,
115
- return_timestamps=True,
116
- )
117
- logging.debug(outputs)
118
- logging.info(print(f"transcribe: {time.time() - start_time} sec."))
119
-
120
- file_out = file.split('/')[-1]
121
- srt = srt_sub.get_subtitle(outputs["chunks"])
122
- vtt = vtt_sub.get_subtitle(outputs["chunks"])
123
- txt = txt_sub.get_subtitle(outputs["chunks"])
124
- write_file(file_out + ".srt", srt)
125
- write_file(file_out + ".vtt", vtt)
126
- write_file(file_out + ".txt", txt)
127
- files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]
 
128
 
129
  progress(1, desc="Completed!")
130
 
 
42
  )
43
  return pipe, model # Return both pipe and model for later GPU switch
44
 
 
45
  def move_to_gpu(model):
46
+ if torch.cuda.is_available():
47
+ device = "cuda:0"
48
+ torch_dtype = torch.float16 # Use float16 precision on GPU
49
+ model.to(device, dtype=torch_dtype)
50
+ elif platform == "darwin":
51
+ device = "mps"
52
+ model.to(device)
53
+ else:
54
+ device = "cpu"
55
+
56
  return device
57
 
58
+ @spaces.GPU
59
  def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
60
  chunk_length_s, batch_size, progress=gr.Progress()):
61
  global last_model
 
84
 
85
  last_model = modelName
86
 
87
+ # Now move the model to GPU after the pipe is created, within the function's context
88
+ with torch.inference_mode():
89
+ device = move_to_gpu(pipe.model)
90
+
91
+ # Update pipe's device
92
+ pipe.device = torch.device(device)
93
+ pipe.model.to(pipe.device)
94
+
95
+ srt_sub = Subtitle("srt")
96
+ vtt_sub = Subtitle("vtt")
97
+ txt_sub = Subtitle("txt")
98
+
99
+ files = []
100
+ if multipleFiles:
101
+ files += multipleFiles
102
+ if urlData:
103
+ files.append(urlData)
104
+ if microphoneData:
105
+ files.append(microphoneData)
106
+ logging.info(files)
107
+
108
+ generate_kwargs = {}
109
+ if languageName != "Automatic Detection" and modelName.endswith(".en") == False:
110
+ generate_kwargs["language"] = languageName
111
+ if modelName.endswith(".en") == False:
112
+ generate_kwargs["task"] = task
113
+
114
+ files_out = []
115
+ for file in progress.tqdm(files, desc="Working..."):
116
+ start_time = time.time()
117
+ logging.info(file)
118
+ outputs = pipe(
119
+ file,
120
+ chunk_length_s=chunk_length_s, # 30
121
+ batch_size=batch_size, # 24
122
+ generate_kwargs=generate_kwargs,
123
+ return_timestamps=True,
124
+ )
125
+ logging.debug(outputs)
126
+ logging.info(print(f"transcribe: {time.time() - start_time} sec."))
127
+
128
+ file_out = file.split('/')[-1]
129
+ srt = srt_sub.get_subtitle(outputs["chunks"])
130
+ vtt = vtt_sub.get_subtitle(outputs["chunks"])
131
+ txt = txt_sub.get_subtitle(outputs["chunks"])
132
+ write_file(file_out + ".srt", srt)
133
+ write_file(file_out + ".vtt", vtt)
134
+ write_file(file_out + ".txt", txt)
135
+ files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]
136
 
137
  progress(1, desc="Completed!")
138