import spaces
import os
import gradio as gr
import torch
import torchaudio
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
from pytube import YouTube

# Load the CTC model and its processor
model = AutoModelForCTC.from_pretrained("anzorq/w2v-bert-2.0-kbd")
processor = Wav2Vec2BertProcessor.from_pretrained("anzorq/w2v-bert-2.0-kbd")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Chunk processing parameters
chunk_length_s = 10 # Chunk length in seconds
stride_length_s = (4, 2) # Stride lengths in seconds
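# The audio is transcribed in overlapping windows of chunk_length_s seconds.
# Each window shares stride_length_s[0] seconds of left context and
# stride_length_s[1] seconds of right context with its neighbours; that context
# is trimmed from the logits before decoding, so words that fall on a window
# boundary are still decoded exactly once.
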
@spaces.GPU  # Request a GPU for the duration of this call (ZeroGPU Spaces)
def transcribe_speech(audio):
    if audio is None:  # Handle missing microphone input
        return "No audio received."

    waveform, sr = torchaudio.load(audio)

    # Resample the audio to 16 kHz if needed
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        waveform = resampler(waveform)
        sr = 16000

    # Convert to mono if needed (torchaudio returns [channels, time])
    if waveform.dim() > 1:
        waveform = torch.mean(waveform, dim=0)

    # Normalize the audio
    waveform = waveform / torch.max(torch.abs(waveform))
    # Chunk the audio into overlapping windows
    chunk_frames = int(chunk_length_s * sr)
    left_stride_frames = int(stride_length_s[0] * sr)
    right_stride_frames = int(stride_length_s[1] * sr)
    step = chunk_frames - left_stride_frames - right_stride_frames

    # Process each chunk, trimming the overlapping context before decoding
    full_transcription = ""
    for start in range(0, waveform.size(0), step):
        chunk = waveform[start:start + chunk_frames]
        is_last = start + chunk_frames >= waveform.size(0)

        # Extract input features for this chunk
        input_features = processor(chunk.numpy(), sampling_rate=sr, return_tensors="pt").input_features
        input_features = input_features.to(device)

        # Generate logits using the model
        with torch.no_grad():
            logits = model(input_features).logits[0]

        # Map the audio strides to logit frames and drop the overlapping
        # context so every part of the audio is decoded exactly once
        ratio = logits.size(0) / chunk.size(0)
        left = 0 if start == 0 else int(left_stride_frames * ratio)
        right = logits.size(0) if is_last else logits.size(0) - int(right_stride_frames * ratio)

        # Decode the predicted ids to text and append to the transcription
        pred_ids = torch.argmax(logits[left:right], dim=-1)
        full_transcription += processor.decode(pred_ids)

        if is_last:
            break

    return full_transcription
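# Example (local sanity check; "sample.wav" is a placeholder path, not part of the app):
#   print(transcribe_speech("sample.wav"))
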
def transcribe_from_youtube(url):
    # Download the audio-only stream from YouTube using pytube
    yt = YouTube(url)
    audio_path = yt.streams.filter(only_audio=True)[0].download(filename="tmp.mp4")

    # Transcribe the downloaded audio, then clean up the temporary file
    try:
        transcription = transcribe_speech(audio_path)
    finally:
        os.remove(audio_path)

    return transcription
def populate_metadata(url):
    yt = YouTube(url)
    return yt.thumbnail_url, yt.title
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 500px; margin: 0 auto;">
          <div>
            <h1>YouTube Speech Transcription</h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 94%">
            Speech-to-text transcription of YouTube videos using Wav2Vec2-BERT
          </p>
        </div>
        """
    )

    with gr.Tab("Microphone Input"):
        gr.Markdown("## Transcribe speech from microphone")
        mic_audio = gr.Audio(sources="microphone", type="filepath", label="Speak into your microphone")
        transcribe_button = gr.Button("Transcribe")
        transcription_output = gr.Textbox(label="Transcription")
        transcribe_button.click(fn=transcribe_speech, inputs=mic_audio, outputs=transcription_output)

    with gr.Tab("YouTube URL"):
        gr.Markdown("## Transcribe speech from YouTube video")
        youtube_url = gr.Textbox(label="Enter YouTube video URL")
        title = gr.Label(label="Video Title")
        img = gr.Image(label="Thumbnail")
        transcribe_button = gr.Button("Transcribe")
        transcription_output = gr.Textbox(label="Transcription", placeholder="Transcription Output", lines=10)
        transcribe_button.click(fn=transcribe_from_youtube, inputs=youtube_url, outputs=transcription_output)
        youtube_url.change(populate_metadata, inputs=[youtube_url], outputs=[img, title])
demo.launch(debug=True)