# Spaces: Sleeping  (status text captured from the hosting page; not part of the program)
import time

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
import torchaudio
from omegaconf import OmegaConf  # ✅ Fix: Import omegaconf
# ✅ 1️⃣ Load Silero STT Model for CPU
# Cap the number of intra-op CPU threads torch may use.
torch.set_num_threads(4)

# All tensors and the model live on the CPU, in single precision.
device = torch.device("cpu")
torch_dtype = torch.float32

# Pull the English Silero STT model, its decoder and the helper bundle from
# torch.hub. trust_repo=True avoids the untrusted-repository warning.
model, decoder, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-models",
    model="silero_stt",
    language="en",
    device=device,
    trust_repo=True,
)

# The helper bundle is a 4-tuple; only prepare_model_input is used below.
read_batch, split_into_batches, read_audio, prepare_model_input = utils
# ✅ 3️⃣ Real-Time Streaming Transcription (Microphone)
def stream_transcribe(stream, new_chunk):
    """Transcribe all microphone audio received so far.

    Each call preprocesses the newest chunk, appends it to the running
    buffer and re-runs the model over the whole buffer.

    Args:
        stream: 1-D float32 NumPy array of already-resampled (16 kHz) audio
            from earlier calls, or ``None`` on the first call.
        new_chunk: ``(sample_rate, samples)`` tuple for the latest chunk.
            When ``samples`` has more than one dimension, axis 1 is averaged
            away to produce mono audio.

    Returns:
        ``(stream, text, latency)``: the updated buffer, the decoded text and
        the processing time formatted as ``"<seconds> sec"``. If anything
        raises, the current buffer, the error message and ``"Error"`` are
        returned instead.
    """
    start_time = time.time()
    try:
        sr, y = new_chunk

        # Mix multi-channel audio down to mono.
        if y.ndim > 1:
            y = y.mean(axis=1)

        y = y.astype(np.float32)

        # Peak-normalise, but leave silent (all-zero) or empty chunks
        # untouched: dividing by a zero peak would write NaNs into the
        # persistent buffer and corrupt every later transcription.
        peak = np.max(np.abs(y)) if y.size else 0.0
        if peak > 0:
            y /= peak

        # Resample the chunk to 16 kHz with torchaudio.
        y_tensor = torch.tensor(y)
        y_resampled = torchaudio.functional.resample(
            y_tensor, orig_freq=sr, new_freq=16000
        ).numpy()

        # Append the chunk to the audio collected so far.
        if stream is not None:
            stream = np.concatenate([stream, y_resampled])
        else:
            stream = y_resampled

        # Build a batch of one from the whole buffer.
        input_tensor = torch.from_numpy(stream).unsqueeze(0)
        input_tensor = prepare_model_input(input_tensor, device=device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            transcription = model(input_tensor)
        text = decoder(transcription[0].cpu())

        latency = time.time() - start_time
        return stream, text, f"{latency:.2f} sec"
    except Exception as e:
        # UI boundary: report the failure in the textbox instead of crashing
        # the streaming event.
        print(f"Error: {e}")
        return stream, str(e), "Error"
# ✅ 4️⃣ Transcription for File Upload
def transcribe(inputs, previous_transcription):
    """Transcribe an uploaded recording and append it to the transcript.

    Args:
        inputs: ``(sample_rate, samples)`` tuple for the uploaded audio.
            When ``samples`` has more than one dimension, axis 1 is averaged
            away to produce mono audio.
        previous_transcription: Text already present in the output box; the
            new text is appended to it.

    Returns:
        ``(transcript, latency)``: the extended transcript and the processing
        time formatted as ``"<seconds> sec"``. If anything raises (including
        ``inputs`` being ``None``), the transcript is returned unchanged and
        the latency field is ``"Error"``.
    """
    start_time = time.time()
    try:
        sample_rate, audio_data = inputs

        # Mirror the preprocessing stream_transcribe applies: mono mix-down,
        # float32 conversion and peak normalisation. The resampler works
        # along the last axis and needs a floating-point tensor, so raw
        # integer or multi-channel input cannot be passed through as-is.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        audio_data = audio_data.astype(np.float32)

        # Skip normalisation for silent or empty audio to avoid 0/0 -> NaN.
        peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
        if peak > 0:
            audio_data /= peak

        # Resample to 16 kHz with torchaudio.
        audio_tensor = torch.tensor(audio_data)
        resampled_audio = torchaudio.functional.resample(
            audio_tensor, orig_freq=sample_rate, new_freq=16000
        ).numpy()

        # Build a batch of one.
        input_tensor = torch.from_numpy(resampled_audio).unsqueeze(0)
        input_tensor = prepare_model_input(input_tensor, device=device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            transcription = model(input_tensor)
        text = decoder(transcription[0].cpu())

        previous_transcription += text
        latency = time.time() - start_time
        return previous_transcription, f"{latency:.2f} sec"
    except Exception as e:
        # UI boundary: keep the existing transcript and flag the failure.
        print(f"Error: {e}")
        return previous_transcription, "Error"
# ✅ 5️⃣ Clear Function
def clear() -> str:
    """Return an empty string, used to blank a transcription textbox."""
    return ""
# ✅ 6️⃣ Gradio Interface (Microphone Streaming)
with gr.Blocks() as microphone:
    # NOTE(review): the heading's trailing emoji was mis-encoded in the source
    # text; it is reconstructed here as the studio-microphone emoji — confirm.
    gr.Markdown("# Silero STT - Real-Time Transcription (Optimized CPU) 🎙️")
    gr.Markdown("Using `Silero STT` for lightweight, accurate speech-to-text.")

    with gr.Row():
        input_audio_microphone = gr.Audio(
            sources=["microphone"], type="numpy", streaming=True
        )
        output = gr.Textbox(label="Live Transcription", value="")
        latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0")

    with gr.Row():
        clear_button = gr.Button("Clear Output")

    # Holds the audio accumulated across streaming callbacks (None at start).
    state = gr.State()

    # Every second of microphone audio is handed to stream_transcribe together
    # with the accumulated buffer; the callback returns the new buffer, the
    # text and the latency.
    input_audio_microphone.stream(
        stream_transcribe,
        [state, input_audio_microphone],
        [state, output, latency_textbox],
        time_limit=30,
        stream_every=1,
    )

    # Clearing must also drop the accumulated audio: stream_transcribe
    # re-transcribes the whole buffer on every chunk, so leaving it in place
    # would make the cleared text reappear with the next chunk.
    clear_button.click(lambda: (None, ""), outputs=[state, output])
# ✅ 7️⃣ Gradio Interface (File Upload)
with gr.Blocks() as file:
    # NOTE(review): the heading's trailing emoji was mis-encoded in the source
    # text; it is reconstructed here as the musical-note emoji — confirm.
    gr.Markdown("# Upload Audio File for Transcription 🎵")
    gr.Markdown("Using `Silero STT` for offline, high-accuracy transcription.")

    with gr.Row():
        input_audio = gr.Audio(sources=["upload"], type="numpy")
        output = gr.Textbox(label="Transcription", value="")
        latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear Output")

    # Submit feeds the uploaded audio and the current transcript to
    # transcribe, which returns the extended transcript and the latency.
    submit_button.click(
        transcribe, [input_audio, output], [output, latency_textbox]
    )
    # Clear only blanks the transcript box.
    clear_button.click(clear, outputs=[output])
# ✅ 8️⃣ Final Gradio App
# Top-level app: a single Ocean-themed container exposing the microphone and
# file-upload UIs as two tabs.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface(
        interface_list=[microphone, file],
        tab_names=["Microphone", "Upload Audio"],
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()