File size: 4,564 Bytes
bfb5ccb
1c4ba6c
8ca2e83
 
 
 
1c4ba6c
8ca2e83
1c803c5
 
8ca2e83
0c872e7
 
 
1c4ba6c
 
 
 
bfb5ccb
8ca2e83
1c4ba6c
 
 
8ca2e83
 
 
550d732
 
 
8ca2e83
 
 
1c4ba6c
 
 
 
8ca2e83
 
 
 
1c4ba6c
 
550d732
1c4ba6c
 
 
 
 
 
 
8ca2e83
1c4ba6c
 
 
8ca2e83
1c4ba6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ca2e83
eaed2c2
1c4ba6c
 
 
 
eaed2c2
550d732
1c4ba6c
 
550d732
 
 
eaed2c2
1c4ba6c
 
 
 
eaed2c2
1c4ba6c
 
 
 
 
 
 
 
 
 
 
 
 
eaed2c2
 
6fd478d
eaed2c2
 
 
 
 
 
 
 
1c4ba6c
 
eaed2c2
1c4ba6c
eaed2c2
 
1c4ba6c
8ca2e83
1c4ba6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import spaces
import os
import gradio as gr
import torch
import torchaudio
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
from pytube import YouTube

# CTC checkpoint on the Hugging Face Hub; the processor is loaded from the same repo.
_MODEL_ID = "anzorq/w2v-bert-2.0-kbd"
model = AutoModelForCTC.from_pretrained(_MODEL_ID)
processor = Wav2Vec2BertProcessor.from_pretrained(_MODEL_ID)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Windowing parameters read by transcribe_speech.
chunk_length_s = 10  # window length, in seconds
stride_length_s = (4, 2)  # (left, right) stride lengths, in seconds

@spaces.GPU
def transcribe_speech(audio):
    """Transcribe an audio file with the CTC model using overlapping windows.

    The audio is resampled to 16 kHz, down-mixed to mono and peak-normalised.
    It is then cut into windows of ``chunk_length_s`` seconds, and consecutive
    windows overlap by ``stride_length_s = (left, right)`` seconds. For each
    window the logits that fall inside the left/right strides are discarded,
    except at the very start and end of the recording. The remaining
    predicted ids are concatenated and the whole sequence is CTC-decoded once,
    so a token spanning a window boundary is not duplicated or split.

    Args:
        audio: Path to an audio file readable by ``torchaudio.load`` (as
            produced by ``gr.Audio(type="filepath")``), or ``None``.

    Returns:
        The transcription text, or ``"No audio received."`` when ``audio`` is
        ``None`` or the file contains no samples.

    Raises:
        ValueError: if ``chunk_length_s`` is not longer than the sum of the
            two stride lengths (the windows would not advance).
    """
    target_sr = 16000

    if audio is None:  # Gradio passes None when nothing was recorded
        return "No audio received."

    waveform, sr = torchaudio.load(audio)  # (channels, samples)

    # Resample if needed. All sample counts below must use the *new* rate.
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
    sr = target_sr

    # Down-mix to a single 1-D channel.
    if waveform.dim() > 1:
        waveform = torch.mean(waveform, dim=0)

    num_samples = waveform.shape[0]
    if num_samples == 0:
        return "No audio received."

    # Peak-normalise. Skip all-zero (silent) input to avoid dividing by zero.
    peak = torch.max(torch.abs(waveform))
    if peak > 0:
        waveform = waveform / peak

    window = int(chunk_length_s * sr)
    left_stride = int(stride_length_s[0] * sr)
    right_stride = int(stride_length_s[1] * sr)
    step = window - left_stride - right_stride
    if step <= 0:
        raise ValueError("chunk_length_s must be longer than the sum of stride_length_s")

    kept_ids = []
    with torch.no_grad():
        for start in range(0, num_samples, step):
            end = min(start + window, num_samples)
            is_last = end >= num_samples
            chunk = waveform[start:end]
            chunk_len = chunk.shape[0]

            inputs = processor(chunk.numpy(), sampling_rate=sr, return_tensors="pt")
            logits = model(inputs.input_features.to(device)).logits[0]  # (frames, vocab)

            # Convert the sample-domain strides into logit frames for this window.
            # Window i's kept region ends exactly where window i+1's kept region
            # begins, so every part of the signal is decoded once.
            frames = logits.shape[0]
            left = 0 if start == 0 else round(left_stride * frames / chunk_len)
            right = 0 if is_last else round(right_stride * frames / chunk_len)
            kept_ids.append(torch.argmax(logits[left:frames - right], dim=-1).cpu())

            if is_last:
                break

    # Decode once over the joined ids so CTC collapsing also applies across windows.
    return processor.decode(torch.cat(kept_ids))

def transcribe_from_youtube(url):
    """Download a YouTube video's audio track and transcribe it.

    The audio is written into a private temporary directory, which is removed
    afterwards even if transcription raises. Concurrent requests therefore
    never share or overwrite the same file.

    Args:
        url: YouTube video URL as entered in the UI.

    Returns:
        The text returned by ``transcribe_speech``, or a short message when the
        video exposes no audio-only stream.
    """
    import tempfile  # local import keeps this block self-contained

    stream = YouTube(url).streams.filter(only_audio=True).first()
    if stream is None:
        return "No audio stream found for this video."

    with tempfile.TemporaryDirectory() as tmp_dir:
        audio_path = stream.download(output_path=tmp_dir, filename="tmp.mp4")
        return transcribe_speech(audio_path)

def populate_metadata(url):
    """Return the thumbnail URL and title of the YouTube video at ``url``.

    This runs whenever the URL text changes, so it regularly sees blank or
    partially typed URLs. In those cases it clears the preview instead of
    raising.

    Args:
        url: Contents of the URL textbox (may be ``None`` or blank).

    Returns:
        ``(thumbnail_url, title)``, or ``(None, None)`` when the URL is blank
        or pytube cannot resolve it.
    """
    if not url or not url.strip():
        return None, None

    # pytube is already a dependency of this file; imported here to stay local.
    from pytube.exceptions import PytubeError

    try:
        yt = YouTube(url)
        return yt.thumbnail_url, yt.title
    except PytubeError:
        # Invalid or unavailable video: the preview is best-effort, so just clear it.
        return None, None

with gr.Blocks() as demo:
    # Page header rendered as raw HTML above the tabs.
    gr.HTML(
        """
            <div style="text-align: center; max-width: 500px; margin: 0 auto;">
              <div>
                <h1>Youtube Speech Transcription</h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                Speech to text transcription of Youtube videos using Wav2Vec2-BERT
              </p>
            </div>
        """
    )

    # Tab 1: record from the browser microphone, then transcribe the recording.
    with gr.Tab("Microphone Input"):
        gr.Markdown("## Transcribe speech from microphone")
        mic_audio = gr.Audio(
            sources="microphone",
            type="filepath",
            label="Speak into your microphone",
        )
        mic_button = gr.Button("Transcribe")
        mic_output = gr.Textbox(label="Transcription")

        mic_button.click(
            fn=transcribe_speech,
            inputs=mic_audio,
            outputs=mic_output,
        )

    # Tab 2: fetch a YouTube video's audio and transcribe it; preview its metadata.
    with gr.Tab("YouTube URL"):
        gr.Markdown("## Transcribe speech from YouTube video")
        youtube_url = gr.Textbox(label="Enter YouTube video URL")
        title = gr.Label(label="Video Title")
        img = gr.Image(label="Thumbnail")
        yt_button = gr.Button("Transcribe")
        yt_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription Output",
            lines=10,
        )

        yt_button.click(
            fn=transcribe_from_youtube,
            inputs=youtube_url,
            outputs=yt_output,
        )
        # Refresh the thumbnail and title whenever the URL text changes.
        youtube_url.change(
            populate_metadata,
            inputs=[youtube_url],
            outputs=[img, title],
        )

demo.launch(debug=True)