import gradio as gr
import torch
import librosa
import numpy as np
import webrtcvad
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
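# Note: these dependencies (gradio, torch, librosa, numpy, webrtcvad, transformers) need to be
# available in the Space, e.g. listed in its requirements.txt.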
# Model names
TN_MODEL_NAME = "amenIKh/Tunisian_Checkpoint12"
WHISPER_MODEL_NAME = "openai/whisper-small"

# Initialize the ASR pipeline for the Tunisian model
pipe_tn = pipeline(
    task="automatic-speech-recognition",
    model=TN_MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1,
)
# Load the Whisper model and processor
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model.to(device)
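# The fine-tuned Tunisian checkpoint handles "tn" input through the pipeline above, while the
# base Whisper model transcribes French ("fr") and English ("en") using forced decoder prompts.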
# Function to apply voice activity detection (VAD) and keep only voiced frames.
# webrtcvad expects 16-bit mono PCM at 8/16/32/48 kHz in frames of 10, 20, or 30 ms.
def apply_vad(audio, sr, frame_duration_ms=30):
    vad = webrtcvad.Vad()
    vad.set_mode(3)  # Aggressiveness mode; a higher value filters non-speech more aggressively
    # librosa returns float32 samples in [-1, 1]; convert to 16-bit PCM for webrtcvad
    pcm_audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    frame_size = int(sr * frame_duration_ms / 1000)
    offset = 0
    voiced_frames = []
    while offset + frame_size <= len(pcm_audio):
        frame = pcm_audio[offset:offset + frame_size]
        if vad.is_speech(frame.tobytes(), sr):
            voiced_frames.append(audio[offset:offset + frame_size])  # keep the float samples for the models
        offset += frame_size
    if len(voiced_frames) == 0:
        return audio  # Return the original audio if no voiced frames are detected
    return np.concatenate(voiced_frames)
# Function to transcribe audio based on the selected language
def transcribe_audio(audio, language):
    try:
        # Load the audio and resample it to 16 kHz, the rate both models expect
        sr = 16000
        audio, _ = librosa.load(audio, sr=sr)
        # Apply VAD to drop non-speech segments
        voiced_audio = apply_vad(audio, sr)
        # Select the correct model based on the language
        if language == "tn":
            result = pipe_tn({"raw": voiced_audio, "sampling_rate": sr})
            transcription = result.get("text", "")
        elif language in ["fr", "en"]:
            forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe")
            input_features = whisper_processor(voiced_audio, sampling_rate=sr, return_tensors="pt").input_features.to(device)
            generated_ids = whisper_model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids
            )
            transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            return "Unsupported language specified"
        return transcription
    except Exception as e:
        return f"An unexpected error occurred: {str(e)}"
# Define the Gradio callback
def gradio_interface(audio, language):
    try:
        if audio is None:
            return "No audio provided"
        # With type="filepath", both uploads and microphone recordings arrive as a file path
        result = transcribe_audio(audio, language)
        return result
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Create the Gradio app
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload Audio"),
        gr.Dropdown(choices=["tn", "fr", "en"], label="Select Language")
    ],
    outputs="text",
    title="ASR Transcription Service",
    description="Upload or record audio and select the language (tn = Tunisian, fr = French, en = English) to transcribe it."
)
# Launch the Gradio app
iface.launch()
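# When running outside a Hugging Face Space, iface.launch(share=True) can expose a temporary public URL.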