dineshb committed on
Commit
576a0b9
·
1 Parent(s): 9efd378

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -14
app.py CHANGED
@@ -3,12 +3,12 @@ import torch
3
  import gradio as gr
4
  import pytube as pt
5
  from transformers import pipeline
6
- from huggingface_hub import model_info
7
 
8
- MODEL_NAME = "openai/whisper-medium" #this always needs to stay in line 8 :D sorry for the hackiness
9
- lang = "en"
10
 
11
  device = 0 if torch.cuda.is_available() else "cpu"
 
12
  pipe = pipeline(
13
  task="automatic-speech-recognition",
14
  model=MODEL_NAME,
@@ -16,9 +16,13 @@ pipe = pipeline(
16
  device=device,
17
  )
18
 
19
- pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
20
 
21
- def transcribe(microphone, file_upload):
 
 
 
 
 
22
  warn_output = ""
23
  if (microphone is not None) and (file_upload is not None):
24
  warn_output = (
@@ -31,7 +35,9 @@ def transcribe(microphone, file_upload):
31
 
32
  file = microphone if microphone is not None else file_upload
33
 
34
- text = pipe(file)["text"]
 
 
35
 
36
  return warn_output + text
37
 
@@ -45,13 +51,15 @@ def _return_yt_html_embed(yt_url):
45
  return HTML_str
46
 
47
 
48
- def yt_transcribe(yt_url):
49
  yt = pt.YouTube(yt_url)
50
  html_embed_str = _return_yt_html_embed(yt_url)
51
  stream = yt.streams.filter(only_audio=True)[0]
52
  stream.download(filename="audio.mp3")
53
 
54
- text = pipe("audio.mp3")["text"]
 
 
55
 
56
  return html_embed_str, text
57
 
@@ -63,26 +71,34 @@ mf_transcribe = gr.Interface(
63
  inputs=[
64
  gr.inputs.Audio(source="microphone", type="filepath", optional=True),
65
  gr.inputs.Audio(source="upload", type="filepath", optional=True),
 
66
  ],
67
  outputs="text",
68
  layout="horizontal",
69
  theme="huggingface",
70
- title="Speech to Text using Open AI",
71
  description=(
72
- "Transcribe long-form microphone or audio inputs with the click of a button!"
 
 
73
  ),
74
  allow_flagging="never",
75
  )
76
 
77
  yt_transcribe = gr.Interface(
78
  fn=yt_transcribe,
79
- inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
 
 
 
80
  outputs=["html", "text"],
81
  layout="horizontal",
82
  theme="huggingface",
83
- title="Speech to Text using Open AI Transcribe YouTube",
84
  description=(
85
- "Transcribe long-form YouTube videos with the click of a button! "
 
 
86
  ),
87
  allow_flagging="never",
88
  )
@@ -90,4 +106,4 @@ yt_transcribe = gr.Interface(
90
  with demo:
91
  gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Transcribe Audio", "Transcribe YouTube"])
92
 
93
- demo.launch(enable_queue=True)
 
3
  import gradio as gr
4
  import pytube as pt
5
  from transformers import pipeline
 
6
 
7
+ MODEL_NAME = "openai/whisper-medium"
8
+ BATCH_SIZE = 8
9
 
10
  device = 0 if torch.cuda.is_available() else "cpu"
11
+
12
  pipe = pipeline(
13
  task="automatic-speech-recognition",
14
  model=MODEL_NAME,
 
16
  device=device,
17
  )
18
 
 
19
 
20
+ all_special_ids = pipe.tokenizer.all_special_ids
21
+ transcribe_token_id = all_special_ids[-5]
22
+ translate_token_id = all_special_ids[-6]
23
+
24
+
25
+ def transcribe(microphone, file_upload, task):
26
  warn_output = ""
27
  if (microphone is not None) and (file_upload is not None):
28
  warn_output = (
 
35
 
36
  file = microphone if microphone is not None else file_upload
37
 
38
+ pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id if task=="transcribe" else translate_token_id]]
39
+
40
+ text = pipe(file, batch_size=BATCH_SIZE)["text"]
41
 
42
  return warn_output + text
43
 
 
51
  return HTML_str
52
 
53
 
54
+ def yt_transcribe(yt_url, task):
55
  yt = pt.YouTube(yt_url)
56
  html_embed_str = _return_yt_html_embed(yt_url)
57
  stream = yt.streams.filter(only_audio=True)[0]
58
  stream.download(filename="audio.mp3")
59
 
60
+ pipe.model.config.forced_decoder_ids = [[2, transcribe_token_id if task=="transcribe" else translate_token_id]]
61
+
62
+ text = pipe("audio.mp3", batch_size=BATCH_SIZE)["text"]
63
 
64
  return html_embed_str, text
65
 
 
71
  inputs=[
72
  gr.inputs.Audio(source="microphone", type="filepath", optional=True),
73
  gr.inputs.Audio(source="upload", type="filepath", optional=True),
74
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
75
  ],
76
  outputs="text",
77
  layout="horizontal",
78
  theme="huggingface",
79
+ title="Whisper Large V2: Transcribe Audio",
80
  description=(
81
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
82
+ f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
83
+ " of arbitrary length."
84
  ),
85
  allow_flagging="never",
86
  )
87
 
88
  yt_transcribe = gr.Interface(
89
  fn=yt_transcribe,
90
+ inputs=[
91
+ gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
92
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe")
93
+ ],
94
  outputs=["html", "text"],
95
  layout="horizontal",
96
  theme="huggingface",
97
+ title="Whisper Large V2: Transcribe YouTube",
98
  description=(
99
+ "Transcribe long-form YouTube videos with the click of a button! Demo uses the checkpoint"
100
+ f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe video files of"
101
+ " arbitrary length."
102
  ),
103
  allow_flagging="never",
104
  )
 
106
  with demo:
107
  gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Transcribe Audio", "Transcribe YouTube"])
108
 
109
+ demo.launch(enable_queue=True)