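# Gradio ASR demo: Tunisian Arabic is transcribed with the fine-tuned
# amenIKh/Tunisian_Checkpoint12 checkpoint, French and English with
# openai/whisper-small, after trimming non-speech audio with WebRTC VAD.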
import gradio as gr
import torch
import librosa
import numpy as np
import webrtcvad
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
# Model names
TN_MODEL_NAME = "amenIKh/Tunisian_Checkpoint12"
WHISPER_MODEL_NAME = "openai/whisper-small"
# Initialize pipelines
pipe_tn = pipeline(
    task="automatic-speech-recognition",
    model=TN_MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1,
)
# Load Whisper model and processor
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model.to(device)

# Apply WebRTC voice activity detection and keep only the voiced frames
def apply_vad(audio, sr, frame_duration_ms=30):
    vad = webrtcvad.Vad()
    vad.set_mode(3)  # Aggressiveness mode: 0 (least) to 3 (most aggressive)
    frame_size = int(sr * frame_duration_ms / 1000)
    offset = 0
    voiced_frames = []
    while offset + frame_size < len(audio):
        frame = audio[offset:offset + frame_size]
        # webrtcvad expects 16-bit PCM bytes, so scale the float frame for the check
        pcm_frame = (frame * 32767).astype(np.int16)
        if vad.is_speech(pcm_frame.tobytes(), sr):
            voiced_frames.append(frame)
        offset += frame_size
    if len(voiced_frames) == 0:
        return audio  # Return the original audio if no voiced frames are detected
    return np.concatenate(voiced_frames)

# Transcribe audio with the model that matches the selected language
def transcribe_audio(audio, language):
    try:
        # Load the audio at 16 kHz, the sampling rate both models expect
        sr = 16000
        audio, _ = librosa.load(audio, sr=sr)
        # Apply VAD to drop non-speech segments
        voiced_audio = apply_vad(audio, sr)
        # Select the correct model based on language
        if language == "tn":
            result = pipe_tn({"raw": voiced_audio, "sampling_rate": sr})
            transcription = result.get("text", "")
        elif language in ["fr", "en"]:
            forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe")
            input_features = whisper_processor(voiced_audio, sampling_rate=sr, return_tensors="pt").input_features.to(device)
            generated_ids = whisper_model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids
            )
            transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            return "Unsupported language specified"
        return transcription
    except Exception as e:
        return f"An unexpected error occurred: {str(e)}"

# Gradio callback: with type="filepath", the Audio component passes a temporary
# file path for both uploaded files and microphone recordings
def gradio_interface(audio, language):
    try:
        if audio is None:
            return "No audio provided"
        # Perform transcription on the temporary file
        return transcribe_audio(audio, language)
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Create the Gradio app
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload Audio"),
        gr.Dropdown(choices=["tn", "fr", "en"], label="Select Language"),
    ],
    outputs="text",
    title="ASR Transcription Service",
    description="Upload an audio file and select the language to transcribe the audio.",
)
# Launch the Gradio app
iface.launch()