# ASR_gradio / audio_processing.py
# Audio processing with optional speaker diarization (WhisperX + pyannote).
import logging
import os
import time
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
from dotenv import load_dotenv
from pyannote.audio import Pipeline

import spaces

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Chunk length and overlap, in seconds, used when splitting audio for transcription
CHUNK_LENGTH = 30
OVERLAP = 2

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@spaces.GPU(duration=60)
def load_whisper_model(model_size="small"):
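    """Load a WhisperX model on GPU (float16) when available, falling back to CPU (int8) on failure."""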
logger.info(f"Loading Whisper model (size: {model_size})...")
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
try:
model = whisperx.load_model(model_size, device, compute_type=compute_type)
logger.info(f"Whisper model loaded successfully on {device}")
return model
except RuntimeError as e:
logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
device = "cpu"
compute_type = "int8"
model = whisperx.load_model(model_size, device, compute_type=compute_type)
logger.info("Whisper model loaded successfully on CPU")
return model
@spaces.GPU(duration=60)
def load_diarization_pipeline():
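    """Load the pyannote speaker-diarization pipeline; returns None if initialization fails."""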
logger.info("Loading diarization pipeline...")
try:
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
if torch.cuda.is_available():
pipeline = pipeline.to(torch.device("cuda"))
logger.info("Diarization pipeline loaded successfully")
return pipeline
except Exception as e:
logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
return None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
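    """Split a 16 kHz waveform into fixed-size overlapping chunks, zero-padding the final chunk."""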
chunks = []
for i in range(0, len(audio), chunk_size - overlap):
chunk = audio[i:i+chunk_size]
if len(chunk) < chunk_size:
chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
chunks.append(chunk)
return chunks
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
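    """Merge consecutive segments that are close in time and share sufficiently overlapping text."""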
merged = []
for segment in segments:
if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
merged.append(segment)
else:
matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
if match.size / len(segment['text']) > similarity_threshold:
merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
merged[-1]['end'] = segment['end']
merged[-1]['text'] = merged_text
if 'translated' in segment:
merged[-1]['translated'] = merged_translated
else:
merged.append(segment)
return merged
def get_most_common_speaker(diarization_result, start_time, end_time):
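    """Return the speaker label whose diarization turns most often overlap [start_time, end_time]."""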
speakers = []
for turn, _, speaker in diarization_result.itertracks(yield_label=True):
if turn.start <= end_time and turn.end >= start_time:
speakers.append(speaker)
return max(set(speakers), key=speakers.count) if speakers else "Unknown"
def split_audio(audio, max_duration=30):
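    """Split a 16 kHz waveform into consecutive pieces of at most max_duration seconds."""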
sample_rate = 16000
max_samples = max_duration * sample_rate
if len(audio) <= max_samples:
return [audio]
splits = []
for i in range(0, len(audio), max_samples):
splits.append(audio[i:i+max_samples])
return splits
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
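    """Transcribe an audio file with WhisperX, optionally translating and diarizing it.

    Returns (language_segments, merged_segments): per-split language spans and the merged
    transcription segments with timestamps, language, speaker, and text.
    """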
logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
start_time = time.time()
try:
whisper_model = load_whisper_model(model_size)
audio = whisperx.load_audio(audio_file)
audio_splits = split_audio(audio)
diarization_result = None
if use_diarization:
diarization_pipeline = load_diarization_pipeline()
if diarization_pipeline is not None:
try:
diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
except Exception as e:
logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
language_segments = []
final_segments = []
for i, audio_split in enumerate(audio_splits):
logger.info(f"Processing split {i+1}/{len(audio_splits)}")
result = whisper_model.transcribe(audio_split)
lang = result["language"]
for segment in result["segments"]:
segment_start = segment["start"] + (i * 30)
segment_end = segment["end"] + (i * 30)
speaker = "Unknown"
if diarization_result is not None:
speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
final_segment = {
"start": segment_start,
"end": segment_end,
"language": lang,
"speaker": speaker,
"text": segment["text"],
}
                if translate:
                    # whisperx's transcribe returns {"segments": [...], "language": ...}; join the per-segment texts
                    translation = whisper_model.transcribe(audio_split[int(segment["start"] * 16000):int(segment["end"] * 16000)], task="translate")
                    final_segment["translated"] = " ".join(s["text"].strip() for s in translation.get("segments", []))
final_segments.append(final_segment)
language_segments.append({
"language": lang,
"start": i * 30,
"end": min((i + 1) * 30, len(audio) / 16000)
})
final_segments.sort(key=lambda x: x["start"])
merged_segments = merge_nearby_segments(final_segments)
end_time = time.time()
logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
return language_segments, merged_segments
except Exception as e:
logger.error(f"An error occurred during audio processing: {str(e)}")
raise
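

# Minimal usage sketch (assumptions: "sample.wav" is a hypothetical local audio file, and this
# module is normally imported by the Gradio app rather than run directly).
if __name__ == "__main__":
    import sys

    audio_path = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"
    lang_segments, segments = process_audio(audio_path, translate=False, model_size="small", use_diarization=False)
    for seg in segments:
        print(f"[{seg['start']:.1f}s-{seg['end']:.1f}s] ({seg['speaker']}, {seg['language']}): {seg['text']}")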