import gradio as gr
import torch
import librosa
import numpy as np
import webrtcvad
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
# Model names
TN_MODEL_NAME = "amenIKh/Tunisian_Checkpoint12"
WHISPER_MODEL_NAME = "openai/whisper-small"
# Initialize the ASR pipeline for the Tunisian model
pipe_tn = pipeline(
    task="automatic-speech-recognition",
    model=TN_MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1,
)
# Load Whisper model and processor
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model.to(device)
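# Note: the Tunisian model runs through the high-level ASR pipeline, while
# Whisper is loaded manually so that forced decoder ids can pin the language
# and task per request (see transcribe_audio below).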
# Function to apply VAD (voice activity detection) and keep only voiced frames.
# webrtcvad expects 16-bit PCM at 8/16/32/48 kHz and frame durations of 10/20/30 ms.
def apply_vad(audio, sr, frame_duration_ms=30):
    vad = webrtcvad.Vad()
    vad.set_mode(3)  # Aggressiveness mode; higher values are more aggressive
    frame_size = int(sr * frame_duration_ms / 1000)
    offset = 0
    voiced_frames = []
    while offset + frame_size < len(audio):
        frame = audio[offset:offset + frame_size]
        # librosa returns float32 in [-1, 1]; webrtcvad needs 16-bit PCM bytes,
        # so scale before converting (a plain astype would truncate to zeros)
        pcm_frame = (frame * 32767).astype(np.int16)
        if vad.is_speech(pcm_frame.tobytes(), sr):
            voiced_frames.append(frame)  # keep the original float frame for the models
        offset += frame_size
    if len(voiced_frames) == 0:
        return audio  # Return the original audio if no voiced frames are detected
    return np.concatenate(voiced_frames)
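# Example usage of apply_vad (illustration only; "sample.wav" is a
# hypothetical local file):
#   audio, sr = librosa.load("sample.wav", sr=16000)
#   voiced = apply_vad(audio, sr)  # float32 array containing only voiced frames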
# Function to transcribe audio based on language
def transcribe_audio(audio, language):
    try:
        # Load the audio and resample to 16 kHz, the rate both models
        # (and webrtcvad) expect
        sr = 16000
        audio, _ = librosa.load(audio, sr=sr)
        # Apply VAD to strip non-speech segments
        voiced_audio = apply_vad(audio, sr)
        # Select the correct model based on language
        if language == "tn":
            result = pipe_tn({"raw": voiced_audio, "sampling_rate": sr})
            transcription = result.get("text", "")
        elif language in ["fr", "en"]:
            forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe")
            input_features = whisper_processor(
                voiced_audio, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)
            generated_ids = whisper_model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
            )
            transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            return "Unsupported language specified"
        return transcription
    except Exception as e:
        return f"An unexpected error occurred: {str(e)}"
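# Quick sanity check without the UI (illustration only; "clip_fr.wav" is a
# hypothetical local recording):
#   print(transcribe_audio("clip_fr.wav", "fr"))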
# Define Gradio interface
def gradio_interface(audio, language):
    try:
        # With type="filepath", Gradio passes a temporary file path for both
        # uploaded files and microphone recordings
        if audio is None:
            return "No audio provided"
        return transcribe_audio(audio, language)
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Create the Gradio app
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload Audio"),
        gr.Dropdown(choices=["tn", "fr", "en"], label="Select Language"),
    ],
    outputs="text",
    title="ASR Transcription Service",
    description="Upload an audio file and select the language to transcribe the audio.",
)
# Launch the Gradio app
iface.launch()
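# Optional: launch(share=True) creates a temporary public link, and
# launch(server_name="0.0.0.0") listens on all interfaces (standard Gradio options).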