# ASR_gradio / audio_processing.py
import os
import time
import logging
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
import spaces
from pyannote.audio import Pipeline
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Chunk parameters (in seconds) used by preprocess_audio
CHUNK_LENGTH = 5
OVERLAP = 2

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@spaces.GPU(duration=60)
def load_whisper_model(model_size="small"):
    logger.info(f"Loading Whisper model (size: {model_size})...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is only supported on GPU; int8 keeps CPU inference practical
    compute_type = "float16" if device == "cuda" else "int8"
    try:
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info(f"Whisper model loaded successfully on {device}")
        return model
    except RuntimeError as e:
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        device = "cpu"
        compute_type = "int8"
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info("Whisper model loaded successfully on CPU")
        return model
@spaces.GPU(duration=60)
def load_diarization_pipeline():
    logger.info("Loading diarization pipeline...")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if torch.cuda.is_available():
            pipeline = pipeline.to(torch.device("cuda"))
        logger.info("Diarization pipeline loaded successfully")
        return pipeline
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        return None
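# Note: "pyannote/speaker-diarization" is a gated model on the Hugging Face
# Hub, so HF_TOKEN must belong to an account that has accepted the model's
# terms; otherwise from_pretrained raises and this function returns None.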
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i+chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
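# Illustrative example: with CHUNK_LENGTH=5 and OVERLAP=2 at 16 kHz, chunks
# start every 3 s, so a 10 s clip yields 4 chunks (at 0, 3, 6, and 9 s),
# each 5 s long, with the last two zero-padded to full length.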
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Adjacent in time: find the longest shared text run and stitch
            # the two segments together without repeating the overlap
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            if len(segment['text']) > 0 and match.size / len(segment['text']) > similarity_threshold:
                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged_text
                if 'translated' in segment:
                    merged[-1]['translated'] = merged_translated
            else:
                merged.append(segment)
    return merged
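# Illustrative example (values are hypothetical): with overlapping splits,
# {'start': 28.0, 'end': 30.0, 'text': 'the quick brown'} followed by
# {'start': 30.0, 'end': 32.0, 'text': 'quick brown fox'} shares the run
# 'quick brown' (11/15 > 0.7), so the two collapse into a single segment
# spanning 28.0-32.0 with text 'the quick brown fox'.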
def get_most_common_speaker(diarization_result, start_time, end_time):
    speakers = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            speakers.append(speaker)
    return max(set(speakers), key=speakers.count) if speakers else "Unknown"
def split_audio(audio, max_duration=30):
    sample_rate = 16000
    max_samples = max_duration * sample_rate
    if len(audio) <= max_samples:
        return [audio]
    splits = []
    for i in range(0, len(audio), max_samples):
        splits.append(audio[i:i+max_samples])
    return splits
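# Illustrative example: a 75 s clip at 16 kHz yields three splits of 30 s,
# 30 s, and 15 s; unlike preprocess_audio, the final split is not padded.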
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
    logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
    start_time = time.time()
    try:
        whisper_model = load_whisper_model(model_size)
        audio = whisperx.load_audio(audio_file)
        audio_splits = split_audio(audio)

        diarization_result = None
        if use_diarization:
            diarization_pipeline = load_diarization_pipeline()
            if diarization_pipeline is not None:
                try:
                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
                except Exception as e:
                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        language_segments = []
        final_segments = []
        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
            result = whisper_model.transcribe(audio_split)
            lang = result["language"]
            for segment in result["segments"]:
                # Segment timestamps are relative to the split; offset by the
                # split's position (each split is at most 30 s long)
                segment_start = segment["start"] + (i * 30)
                segment_end = segment["end"] + (i * 30)
                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }
                if translate:
                    translation = whisper_model.transcribe(audio_split[int(segment["start"]*16000):int(segment["end"]*16000)], task="translate")
                    # whisperx returns a list of segments rather than a single
                    # "text" field; join them to get the translated text
                    final_segment["translated"] = " ".join(seg["text"] for seg in translation["segments"])
                final_segments.append(final_segment)
            language_segments.append({
                "language": lang,
                "start": i * 30,
                "end": min((i + 1) * 30, len(audio) / 16000)
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
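
# Minimal usage sketch (assumes this module runs outside the Spaces runtime,
# where @spaces.GPU is a no-op, and that "sample.wav" exists; the file name
# is hypothetical, not part of the original module):
if __name__ == "__main__":
    language_segments, merged_segments = process_audio(
        "sample.wav", translate=False, model_size="small", use_diarization=False
    )
    for seg in merged_segments:
        print(f"[{seg['start']:.1f}-{seg['end']:.1f}] {seg['speaker']} "
              f"({seg['language']}): {seg['text']}")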