anzorq committed
Commit 1c4ba6c · verified · 1 parent: 550d732

Update app.py

Files changed (1)
  1. app.py +74 -39
app.py CHANGED
@@ -1,9 +1,10 @@
  import spaces
+ import os
  import gradio as gr
  import torch
  import torchaudio
  from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
- import yt_dlp
+ from pytube import YouTube

  model = AutoModelForCTC.from_pretrained("anzorq/w2v-bert-2.0-kbd")
  processor = Wav2Vec2BertProcessor.from_pretrained("anzorq/w2v-bert-2.0-kbd")
@@ -11,9 +12,15 @@ processor = Wav2Vec2BertProcessor.from_pretrained("anzorq/w2v-bert-2.0-kbd")
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)

+ # Chunk processing parameters
+ chunk_length_s = 10  # Chunk length in seconds
+ stride_length_s = (4, 2)  # Stride lengths in seconds
+
  @spaces.GPU
  def transcribe_speech(audio):
-     # Load the audio file
+     if audio is None:  # Handle the NoneType error for microphone input
+         return "No audio received."
+
      waveform, sr = torchaudio.load(audio)

      # Resample the audio if needed
@@ -23,57 +30,82 @@ def transcribe_speech(audio):

      # Convert to mono if needed
      if waveform.dim() > 1:
-         waveform = torchaudio.transforms.DownmixMono()(waveform)
+         waveform = torch.mean(waveform, dim=0)
+
+     # Ensure the waveform is a 2D tensor for chunking
+     waveform = waveform.unsqueeze(0)  # Add a dimension if it's mono

      # Normalize the audio
      waveform = waveform / torch.max(torch.abs(waveform))

-     # Extract input features
-     with torch.no_grad():
-         input_features = processor(waveform.unsqueeze(0), sampling_rate=16000).input_features
-         input_features = torch.from_numpy(input_features).to(device)
+     # Chunk the audio
+     chunks = torch.split(waveform, int(chunk_length_s * sr), dim=1)

-     # Generate logits using the model
-     logits = model(input_features).logits
+     # Process each chunk with striding
+     full_transcription = ""
+     for i, chunk in enumerate(chunks):
+         with torch.no_grad():
+             # Calculate stride lengths in frames
+             left_stride_frames = int(stride_length_s[0] * sr)
+             right_stride_frames = int(stride_length_s[1] * sr)

-     # Decode the predicted ids to text
-     pred_ids = torch.argmax(logits, dim=-1)[0]
-     pred_text = processor.decode(pred_ids)
+             # Extract the effective chunk with stride
+             start_frame = max(0, left_stride_frames * (i - 1))
+             end_frame = min(chunk.size(1), chunk.size(1) - right_stride_frames * i)

-     return pred_text
+             # Check for negative duration before processing
+             if end_frame <= start_frame:
+                 continue  # Skip this chunk
+
+             effective_chunk = chunk[:, start_frame:end_frame]
+
+             # Extract input features
+             input_features = processor(effective_chunk, sampling_rate=16000).input_features
+             input_features = torch.from_numpy(input_features).to(device)
+
+             # Generate logits using the model
+             logits = model(input_features).logits
+
+             # Decode the predicted ids to text
+             pred_ids = torch.argmax(logits, dim=-1)[0]
+             pred_text = processor.decode(pred_ids)
+
+             # Append the chunk's transcription to the full transcription
+             full_transcription += pred_text
+
+     return full_transcription

- @spaces.GPU
  def transcribe_from_youtube(url):
-     # Download audio from YouTube using yt-dlp
-     audio_path = f"downloaded_audio_{url.split('=')[-1]}.wav"
-     ydl_opts = {
-         'format': 'bestaudio/best',
-         'outtmpl': audio_path,
-         'postprocessors': [{
-             'key': 'FFmpegExtractAudio',
-             'preferredcodec': 'wav',
-             'preferredquality': '192',
-         }],
-         'postprocessor_args': ['-ar', '16000'],  # Ensure audio is at 16000 Hz
-         'prefer_ffmpeg': True,
-     }
-
-     with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-         ydl.download([url])
-
-     # # Check if the file exists
-     # if not os.path.exists(audio_path):
-     #     raise FileNotFoundError(f"Failed to find the audio file {audio_path}")
-
+     # Download audio from YouTube using pytube
+     yt = YouTube(url)
+     audio_path = yt.streams.filter(only_audio=True)[0].download(filename="tmp.mp4")
+
      # Transcribe the downloaded audio
      transcription = transcribe_speech(audio_path)
-
-     # Optionally, clean up the downloaded file
+
+     # Clean up the downloaded file
      os.remove(audio_path)

      return transcription

+ def populate_metadata(url):
+     yt = YouTube(url)
+     return yt.thumbnail_url, yt.title
+
  with gr.Blocks() as demo:
+     gr.HTML(
+         """
+         <div style="text-align: center; max-width: 500px; margin: 0 auto;">
+           <div>
+             <h1>Youtube Speech Transcription</h1>
+           </div>
+           <p style="margin-bottom: 10px; font-size: 94%">
+             Speech to text transcription of Youtube videos using Wav2Vec2-BERT
+           </p>
+         </div>
+         """
+     )
+
      with gr.Tab("Microphone Input"):
          gr.Markdown("## Transcribe speech from microphone")
          mic_audio = gr.Audio(sources="microphone", type="filepath", label="Speak into your microphone")
@@ -85,9 +117,12 @@ with gr.Blocks() as demo:
      with gr.Tab("YouTube URL"):
          gr.Markdown("## Transcribe speech from YouTube video")
          youtube_url = gr.Textbox(label="Enter YouTube video URL")
+         title = gr.Label(label="Video Title")
+         img = gr.Image(label="Thumbnail")
          transcribe_button = gr.Button("Transcribe")
-         transcription_output = gr.Textbox(label="Transcription")
+         transcription_output = gr.Textbox(label="Transcription", placeholder="Transcription Output", lines=10)

          transcribe_button.click(fn=transcribe_from_youtube, inputs=youtube_url, outputs=transcription_output)
+         youtube_url.change(populate_metadata, inputs=[youtube_url], outputs=[img, title])

- demo.launch()
+ demo.launch(debug=True)
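
A minimal sketch (not part of the commit) of the frame arithmetic the new chunking loop in transcribe_speech performs, assuming a hypothetical 25-second mono clip already at 16 kHz; chunk_length_s, stride_length_s, and the start/end formulas are copied from the diff above, everything else is an illustrative assumption:

sr = 16000
chunk_length_s = 10
stride_length_s = (4, 2)

total_frames = 25 * sr                    # assumed example: a 25-second mono clip
chunk_frames = int(chunk_length_s * sr)   # 160000 frames per chunk, as in the diff

# Sizes that torch.split(waveform, chunk_frames, dim=1) would produce for this input
chunk_sizes = [min(chunk_frames, total_frames - start)
               for start in range(0, total_frames, chunk_frames)]

left_stride_frames = int(stride_length_s[0] * sr)    # 64000
right_stride_frames = int(stride_length_s[1] * sr)   # 32000

for i, size in enumerate(chunk_sizes):
    start_frame = max(0, left_stride_frames * (i - 1))
    end_frame = min(size, size - right_stride_frames * i)
    if end_frame <= start_frame:
        print(f"chunk {i}: size={size} frames, skipped")
    else:
        print(f"chunk {i}: size={size} frames, transcribed slice {start_frame}:{end_frame}")

With these example numbers the loop transcribes chunk 0 in full, trims the last 2 seconds off chunk 1, and skips the final 5-second chunk entirely; the exact slices depend on the clip length.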