reedmayhew committed
Commit 9b8d36a · verified · 1 Parent(s): 8ea6044

Update app.py

Files changed (1): app.py (+38 -33)
app.py CHANGED
@@ -8,15 +8,15 @@ from transformers.utils import is_flash_attn_2_available
 from languages import get_language_names
 from subtitle_manager import Subtitle

-
 logging.basicConfig(level=logging.INFO)
 last_model = None
 pipe = None

-def write_file(output_file,subtitle):
+def write_file(output_file, subtitle):
     with open(output_file, 'w', encoding='utf-8') as f:
         f.write(subtitle)

+@spaces.GPU
 def create_pipe(model, flash):
     if torch.cuda.is_available():
         device = "cuda:0"
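Note: this hunk is the top of create_pipe, which selects a device and builds the ASR pipeline; its body below the shown context is not part of this diff. A minimal sketch of what such a builder typically looks like, using only the imports visible in the file (the real create_pipe may differ, so treat this as an illustration, not the author's code):

    import torch
    from transformers import pipeline
    from transformers.utils import is_flash_attn_2_available


    def create_pipe_sketch(model, flash):
        # Prefer CUDA when present; fall back to CPU otherwise.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        attn = "flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa"
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
            device=device,
            model_kwargs={"attn_implementation": attn},  # forwarded to from_pretrained
        )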
@@ -55,8 +55,9 @@ def create_pipe(model, flash):
     )
     return pipe

+@spaces.GPU
 def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
-        chunk_length_s, batch_size, progress=gr.Progress()):
+                                     chunk_length_s, batch_size, progress=gr.Progress()):
     global last_model
     global pipe

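Note: the two +@spaces.GPU lines are the functional change in this commit; the rest is formatting. On ZeroGPU Spaces, CUDA is only usable inside functions decorated with spaces.GPU, which attaches a GPU for the duration of the call (the decorator also accepts a duration argument for longer jobs). The diff does not show a matching `import spaces`, so presumably one already exists higher up in app.py. A minimal sketch of the pattern:

    import spaces  # Hugging Face "spaces" package, available on Spaces hardware
    import torch


    @spaces.GPU  # a ZeroGPU device is attached only while this call runs
    def gpu_probe():
        # Inside the decorated call CUDA is visible; outside it, it is not (on ZeroGPU).
        return torch.zeros(1, device="cuda").item()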
 
@@ -69,7 +70,7 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF
     logging.info(f"chunk_length_s: {chunk_length_s}")
     logging.info(f"batch_size: {batch_size}")

-    if last_model == None:
+    if last_model is None:
         logging.info("first model")
         progress(0.1, desc="Loading Model..")
         pipe = create_pipe(modelName, flash)
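Note: `last_model == None` becoming `last_model is None` is the idiomatic identity test (PEP 8). Beyond style, `==` dispatches to `__eq__` and can lie, while `is` cannot:

    class Sneaky:
        def __eq__(self, other):
            return True  # claims equality with everything, including None


    s = Sneaky()
    print(s == None)  # True  -- __eq__ makes the comparison misleading
    print(s is None)  # False -- identity cannot be overridden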
@@ -88,7 +89,7 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF

     files = []
     if multipleFiles:
-        files+=multipleFiles
+        files += multipleFiles
     if urlData:
         files.append(urlData)
     if microphoneData:
@@ -107,8 +108,8 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF
         logging.info(file)
         outputs = pipe(
             file,
-            chunk_length_s=chunk_length_s,#30
-            batch_size=batch_size,#24
+            chunk_length_s=chunk_length_s,  # 30
+            batch_size=batch_size,  # 24
             generate_kwargs=generate_kwargs,
             return_timestamps=True,
         )
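Note: the two retouched keyword arguments drive transformers' chunked long-form decoding: chunk_length_s splits the audio into 30-second windows and batch_size decodes 24 windows per forward pass. A standalone sketch of the same call, with a hypothetical input path in place of the app's file loop:

    from transformers import pipeline

    asr = pipeline("automatic-speech-recognition", model="distil-whisper/distil-large-v2")
    outputs = asr(
        "audio.wav",             # hypothetical input file
        chunk_length_s=30,       # window length for chunked long-form decoding
        batch_size=24,           # windows decoded per forward pass
        return_timestamps=True,  # required for the "chunks" the subtitle code consumes
    )
    print(outputs["text"])       # full transcript
    print(outputs["chunks"][0])  # e.g. {'timestamp': (0.0, 5.2), 'text': '...'}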
@@ -119,13 +120,13 @@ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleF
         srt = srt_sub.get_subtitle(outputs["chunks"])
         vtt = vtt_sub.get_subtitle(outputs["chunks"])
         txt = txt_sub.get_subtitle(outputs["chunks"])
-        write_file(file_out+".srt",srt)
-        write_file(file_out+".vtt",vtt)
-        write_file(file_out+".txt",txt)
-        files_out += [file_out+".srt", file_out+".vtt", file_out+".txt"]
+        write_file(file_out + ".srt", srt)
+        write_file(file_out + ".vtt", vtt)
+        write_file(file_out + ".txt", txt)
+        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

     progress(1, desc="Completed!")
-
+
     return files_out, vtt, txt


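Note: outputs["chunks"] is a list of {"timestamp": (start, end), "text": ...} dicts, which the repo's Subtitle class renders as SRT/VTT/TXT. For reference, a minimal SRT renderer over the same structure (independent of subtitle_manager, whose interface this diff only hints at):

    def to_srt(chunks):
        # chunks: [{"timestamp": (start, end), "text": "..."}, ...]
        def fmt(t):
            total_ms = int(round(t * 1000))
            h, rem = divmod(total_ms, 3_600_000)
            m, rem = divmod(rem, 60_000)
            s, ms = divmod(rem, 1000)
            return f"{h:02}:{m:02}:{s:02},{ms:03}"  # SRT uses a comma before milliseconds

        cues = []
        for i, chunk in enumerate(chunks, start=1):
            start, end = chunk["timestamp"]
            cues.append(f"{i}\n{fmt(start)} --> {fmt(end)}\n{chunk['text'].strip()}\n")
        return "\n".join(cues)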
 
@@ -142,7 +143,7 @@ with gr.Blocks(title="Insanely Fast Whisper") as demo:
         "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
         "openai/whisper-large-v3", "distil-whisper/distil-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
     ]
-    waveform_options=gr.WaveformOptions(
+    waveform_options = gr.WaveformOptions(
         waveform_color="#01C6FF",
         waveform_progress_color="#0066B4",
         skip_length=2,
@@ -150,25 +151,29 @@ with gr.Blocks(title="Insanely Fast Whisper") as demo:
     )

     simple_transcribe = gr.Interface(fn=transcribe_webui_simple_progress,
-        description=description,
-        article=article,
-        inputs=[
-            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive = True,),
-            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive = True,),
-            gr.Text(label="URL", info="(YouTube, etc.)", interactive = True),
-            gr.File(label="Upload Files", file_count="multiple"),
-            gr.Audio(sources=["upload", "microphone",], type="filepath", label="Input", waveform_options = waveform_options),
-            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive = True),
-            gr.Checkbox(label='Flash',info='Use Flash Attention 2'),
-            gr.Number(label='chunk_length_s',value=30, interactive = True),
-            gr.Number(label='batch_size',value=24, interactive = True)
-        ], outputs=[
-            gr.File(label="Download"),
-            gr.Text(label="Transcription"),
-            gr.Text(label="Segments")
-        ]
-    )
+                                     description=description,
+                                     article=article,
+                                     inputs=[
+                                         gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2",
+                                                     label="Model", info="Select whisper model", interactive=True, ),
+                                         gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()),
+                                                     value="Automatic Detection", label="Language",
+                                                     info="Select audio voice language", interactive=True, ),
+                                         gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
+                                         gr.File(label="Upload Files", file_count="multiple"),
+                                         gr.Audio(sources=["upload", "microphone", ], type="filepath", label="Input",
+                                                  waveform_options=waveform_options),
+                                         gr.Dropdown(choices=["transcribe", "translate"], label="Task",
+                                                     value="transcribe", interactive=True),
+                                         gr.Checkbox(label='Flash', info='Use Flash Attention 2'),
+                                         gr.Number(label='chunk_length_s', value=30, interactive=True),
+                                         gr.Number(label='batch_size', value=24, interactive=True)
+                                     ], outputs=[
+                                         gr.File(label="Download"),
+                                         gr.Text(label="Transcription"),
+                                         gr.Text(label="Segments")
+                                     ]
+                                     )

 if __name__ == "__main__":
-    demo.launch()
-
+    demo.launch()
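Note: the reformatted call constructs a gr.Interface inside the surrounding `with gr.Blocks(...) as demo:` context, the embedding pattern this file already relies on (an Interface is itself a Blocks and renders in place). Stripped to its skeleton, with shortened component lists and a stand-in function:

    import gradio as gr


    def echo(url):
        return url


    with gr.Blocks(title="Insanely Fast Whisper") as demo:
        gr.Interface(
            fn=echo,
            inputs=[gr.Text(label="URL", info="(YouTube, etc.)", interactive=True)],
            outputs=[gr.Text(label="Transcription")],
        )

    if __name__ == "__main__":
        demo.launch()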
 