asr-inference / whisper_cs_dev.py
federicocosta1989's picture
Update whisper_cs_dev.py
30c975d verified
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
import whisper_timestamped as whisper_ts
from faster_whisper import WhisperModel
from settings import DEBUG_MODE, MODEL_PATH_V2_FAST, MODEL_PATH_V2, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, FAKE_AUDIO_PATH, RESAMPLING_FREQ
import time
def get_settings():
    """Choose the inference device and numeric precision for the models.

    Returns:
        tuple[str, str]: ``(device, compute_type)``; ``("cuda", "float16")`` when CUDA
        is treated as available, otherwise ``("cpu", "int8")``.

    NOTE(review): CUDA availability is hard-coded to ``True`` (the
    ``torch.cuda.is_available()`` probe is disabled), so this currently always
    returns the CUDA settings.
    """
    if DEBUG_MODE:
        print(f"Entering get_settings function...")

    # HACK hardcoding this to try; the real probe would be torch.cuda.is_available()
    is_cuda_available = True

    device, compute_type = ("cuda", "float16") if is_cuda_available else ("cpu", "int8")

    if DEBUG_MODE:
        print(f"is_cuda_available: {is_cuda_available}")
        print(f"device: {device}")
        print(f"compute_type: {compute_type}")
        print(f"Exited get_settings function.")
    return device, compute_type
def load_model(use_v2_fast, device, compute_type):
    """Load the transcription model plus the whisper_timestamped "warm-up" model.

    Args:
        use_v2_fast: If true, load ``MODEL_PATH_V2_FAST`` with faster_whisper's
            ``WhisperModel``; otherwise load ``MODEL_PATH_V2`` with whisper_timestamped.
        device: Device string forwarded to the loaders (e.g. ``"cuda"``).
        compute_type: Precision forwarded to ``WhisperModel`` only; the
            whisper_timestamped loader does not receive it (see TODO below).

    Returns:
        tuple: ``(model, fake_model)``. ``fake_model`` is a whisper_timestamped model
        that ``transcribe_channels`` runs once on a dummy file before the real work.
    """
    if DEBUG_MODE: print(f"Entering load_model function...")
    if DEBUG_MODE: print(f"use_v2_fast: {use_v2_fast}")
    if use_v2_fast:
        if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2_FAST} using {device} with {compute_type}...")
        model = WhisperModel(
            MODEL_PATH_V2_FAST,
            device=device,
            compute_type=compute_type,
        )
        # HACK we need to do this for strange reasons.
        # If we don't do this, we get:
        # Could not load library libcudnn_ops_infer.so.8. Error: libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory
        fake_model = whisper_ts.load_model(MODEL_PATH_V2, device=device)
    else:
        if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2} using {device} with {compute_type}...")
        # TODO add compute_type to load model
        model = whisper_ts.load_model(
            MODEL_PATH_V2,
            device=device,
        )
        # The main model already is a whisper_timestamped load of MODEL_PATH_V2, so
        # reuse it for the warm-up pass instead of putting a second identical copy
        # of the weights on the device.
        fake_model = model
    if DEBUG_MODE: print(f"Exiting load_model function...")
    return model, fake_model
def split_input_stereo_channels(audio_path):
    """Split a stereo ``.wav``/``.mp3`` file into two mono WAV temp files.

    The left channel is written to ``LEFT_CHANNEL_TEMP_PATH`` and the right channel
    to ``RIGHT_CHANNEL_TEMP_PATH``.

    Args:
        audio_path: Path to the input audio; the extension is matched case-insensitively.

    Raises:
        ValueError: If the extension is neither ``.wav`` nor ``.mp3``, or the audio
            does not have exactly two channels.
    """
    if DEBUG_MODE: print(f"Entering split_input_stereo_channels function...")
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format for: {audio_path}")

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} has {len(channels)} channels (instead of 2).")

    # pydub's split_to_mono() returns the left channel at index 0 and the right
    # channel at index 1, so index 0 belongs in the LEFT temp file.
    left_channel, right_channel = channels
    # export() returns the still-open output file handle; close it explicitly so
    # the WAV is flushed and released before it is read back.
    left_channel.export(LEFT_CHANNEL_TEMP_PATH, format="wav").close()
    right_channel.export(RIGHT_CHANNEL_TEMP_PATH, format="wav").close()
    if DEBUG_MODE: print(f"Exited split_input_stereo_channels function.")
def format_audio(audio_path):
    """Load an audio file as a mono 1-D waveform resampled to ``RESAMPLING_FREQ``.

    Args:
        audio_path: Path readable by ``torchaudio.load``.

    Returns:
        tuple: ``(waveform, RESAMPLING_FREQ)`` where ``waveform`` is a 1-D ``torch.Tensor``.
    """
    if DEBUG_MODE: print(f"Entering format_audio function...")
    input_audio, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames); average any multi-channel input
    # down to a single channel, not only the exactly-stereo case.
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio)
    # Drop only the channel axis so even a single-frame clip stays 1-D.
    input_audio = input_audio.squeeze(0)
    if DEBUG_MODE: print(f"Exited format_audio function.")
    return input_audio, RESAMPLING_FREQ
def process_waveforms():
    """Load both channel temp files as float16 NumPy waveforms.

    Reads ``LEFT_CHANNEL_TEMP_PATH`` and ``RIGHT_CHANNEL_TEMP_PATH`` through
    ``format_audio`` (mono, resampled) and casts each result to float16.

    Returns:
        tuple: ``(left_waveform, right_waveform)`` as float16 ``numpy.ndarray`` objects.
    """
    if DEBUG_MODE: print(f"Entering process_waveforms function...")
    # TODO should this be equal to compute_type?
    # NOTE(review): samples are down-cast to float16 here; confirm the downstream
    # transcribers expect this precision rather than float32.
    converted = []
    for channel_path in (LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH):
        waveform_tensor, _ = format_audio(channel_path)
        converted.append(waveform_tensor.numpy().astype("float16"))
    left_waveform, right_waveform = converted
    if DEBUG_MODE: print(f"Exited process_waveforms function.")
    return left_waveform, right_waveform
def transcribe_audio_no_fast_model(model, audio_path):
    """Transcribe one audio input with whisper_timestamped and collect word timings.

    Args:
        model: A model returned by ``whisper_ts.load_model``.
        audio_path: Audio accepted by ``whisper_ts.transcribe``; ``transcribe_channels``
            passes NumPy waveforms here despite the parameter name.

    Returns:
        dict: ``audio_path`` (the input as given), ``text`` (stripped full text),
        ``segments`` (raw segment list, ``[]`` if absent), ``words`` (list of dicts
        with ``word``/``start``/``end``/``confidence``, missing values defaulting to
        ``''``/0), ``duration`` (0 if absent) and ``success`` (always ``True``).
    """
    if DEBUG_MODE: print(f"Entering transcribe_audio_no_fast_model function...")
    result = whisper_ts.transcribe(
        model,
        audio_path,
        beam_size=5,
        best_of=5,
        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        vad=False,
        detect_disfluencies=True,
    )
    segments = result.get('segments', [])

    words = []
    for segment in segments:
        for word in segment.get('words', []):
            words.append({
                # strip() already removes any leading space, no extra check needed.
                'word': word.get('word', '').strip(),
                'start': word.get('start', 0),
                'end': word.get('end', 0),
                'confidence': word.get('confidence', 0),
            })

    # Emit the exit log before returning so it is actually reached.
    if DEBUG_MODE: print(f"Exited transcribe_audio_no_fast_model function.")
    return {
        'audio_path': audio_path,
        'text': result['text'].strip(),
        'segments': segments,
        'words': words,
        'duration': result.get('duration', 0),
        'success': True,
    }
def transcribe_channels(left_waveform, right_waveform, model, use_v2_fast, fake_model):
    """Transcribe the left and right channel waveforms independently.

    A throw-away whisper_timestamped pass over ``FAKE_AUDIO_PATH`` with ``fake_model``
    always runs first (see HACK below).

    Args:
        left_waveform, right_waveform: Per-channel audio passed to the model.
        model: faster_whisper ``WhisperModel`` if ``use_v2_fast``, else a
            whisper_timestamped model.
        use_v2_fast: Selects which transcription API is used.
        fake_model: whisper_timestamped model used only for the warm-up pass.

    Returns:
        tuple: ``(left_result, right_result)``. With ``use_v2_fast`` each is a list of
        faster_whisper segments; otherwise each is the dict returned by
        ``transcribe_audio_no_fast_model``.
    """
    if DEBUG_MODE: print(f"Entering transcribe_channels function...")
    # HACK we need to do this for strange reasons.
    # If we don't do this, we get:
    # Could not load library libcudnn_ops_infer.so.8. Error: libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory
    whisper_ts.transcribe(fake_model, FAKE_AUDIO_PATH, beam_size=1)
    if DEBUG_MODE: print(f"Preparing to transcribe...")

    if not use_v2_fast:
        left_result = transcribe_audio_no_fast_model(model, left_waveform)
        right_result = transcribe_audio_no_fast_model(model, right_waveform)
    else:
        # model.transcribe returns (segments, info); the segments iterable is
        # materialized into a list only after both calls have been issued.
        left_segments, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
        right_segments, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
        left_result = list(left_segments)
        right_result = list(right_segments)

    if DEBUG_MODE: print(f"Exited transcribe_channels function.")
    return left_result, right_result
# TODO refactor and rename this function
def post_process_transcription(transcription, max_repeats=2):
    """Clean a transcription string by removing stutter-like repetitions.

    The text is tokenized into words (allowing one internal apostrophe, e.g. ``l'home``,
    and one trailing ``.``, ``,``, ``!`` or ``?``); any other characters are dropped.
    Inside each token, a run where a 1-3 letter unit occurs three or more times in a
    row (e.g. ``"hahaha"``) is deleted entirely. Digits are never treated as such a
    unit, so numbers like ``"1000"`` are kept intact. Finally, runs of identical
    consecutive tokens are capped at ``max_repeats`` occurrences.

    Args:
        transcription: Text to clean.
        max_repeats: Maximum number of identical consecutive tokens to keep.

    Returns:
        str: The surviving tokens joined by single spaces, stripped.
    """
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # [^\W\d_] matches letters only: with \w, digit runs were collapsed too
        # (e.g. "1000" -> "1", "2000" -> "2").
        reduced_token = re.sub(r"([^\W\d_]{1,3})(\1{2,})", "", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    # Tokens reduced to "" leave doubled spaces; collapse them.
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
# TODO not used right now, decide to use it or not
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive ``[SPEAKER_NN]`` segments that belong to the same speaker.

    Text before the first speaker tag is discarded. Each merged turn becomes one line
    ``[SPEAKER_NN] <segment texts joined by single spaces>``, and the whole result is
    stripped of leading/trailing whitespace. For example,
    ``"[SPEAKER_00] hi [SPEAKER_00] there [SPEAKER_01] yo"`` becomes two lines:
    ``"[SPEAKER_00] hi there"`` and ``"[SPEAKER_01] yo"``.
    """
    parts = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    # Because the tag pattern is captured, re.split yields
    # [prefix, tag, text, tag, text, ...]; pair each tag with the text after it.
    turns = []  # each entry: [speaker_id, [segment texts]]
    for tag, chunk in zip(parts[1::2], parts[2::2]):
        speaker_id = re.search(r'\d{2}', tag).group()
        if turns and turns[-1][0] == speaker_id:
            turns[-1][1].append(chunk.strip())
        else:
            turns.append([speaker_id, [chunk.strip()]])
    lines = [f'[SPEAKER_{speaker_id}] {" ".join(texts)}\n' for speaker_id, texts in turns]
    return ''.join(lines).strip()
def get_segments(result, speaker_label, use_v2_fast):
    """Turn one channel's transcription result into labelled, cleaned segments.

    Args:
        result: With ``use_v2_fast``, an iterable of objects exposing ``start``, ``end``
            and ``text`` attributes; otherwise a dict whose optional ``"segments"``
            entry is a list of dicts with ``start``/``end``/``text`` keys.
        speaker_label: Label stored in every output tuple.
        use_v2_fast: Selects which of the two result shapes above is read.

    Returns:
        list: ``(start, end, speaker_label, cleaned_text)`` tuples, one per segment whose
        text is non-empty, with text passed through ``post_process_transcription``.
        Missing dict ``start``/``end`` values default to ``0.0``.
    """
    if DEBUG_MODE: print(f"Entering get_segments function...")
    if use_v2_fast:
        raw_segments = [(seg.start, seg.end, seg.text) for seg in result if seg.text]
    else:
        raw_segments = [
            (seg.get("start", 0.0), seg.get("end", 0.0), seg.get("text"))
            for seg in result.get("segments", [])
            if seg.get("text")
        ]
    final_segments = [
        (start, end, speaker_label, post_process_transcription(text.strip()))
        for start, end, text in raw_segments
    ]
    if DEBUG_MODE: print(f"EXited get_segments function.")
    return final_segments
def post_process_transcripts(left_result, right_result, use_v2_fast):
    """Merge both channels' transcriptions into one chronologically ordered transcript.

    Left-channel segments are labelled ``"Speaker 1"`` and right-channel ones
    ``"Speaker 2"``. Segments are ordered by start time; a ``None`` start sorts last,
    and ties keep left-before-right order because ``sorted`` is stable.

    Returns:
        str: One ``[<speaker>]: <text>`` line per segment, with leading/trailing
        whitespace stripped from the whole string.
    """
    if DEBUG_MODE: print(f"Entering post_process_transcripts function...")

    def _start_key(segment):
        start_time = segment[0]
        return float("inf") if start_time is None else float(start_time)

    merged_transcript = sorted(
        get_segments(left_result, "Speaker 1", use_v2_fast)
        + get_segments(right_result, "Speaker 2", use_v2_fast),
        key=_start_key,
    )
    clean_output = "\n".join(
        f"[{speaker}]: {text}" for _start, _end, speaker, text in merged_transcript
    ).strip()
    if DEBUG_MODE: print(f"Exited post_process_transcripts function.")
    return clean_output
def cleanup_temp_files(*file_paths):
    """Delete each given path that is truthy and currently exists.

    Falsy entries (``None``, ``""``) and paths for which ``os.path.exists`` is false
    are skipped silently.
    """
    if DEBUG_MODE:
        print(f"Entered cleanup_temp_files function...")
        print(f"File paths to remove: {file_paths}")
    for candidate in file_paths:
        if not candidate or not os.path.exists(candidate):
            continue
        if DEBUG_MODE: print(f"Removing path: {candidate}")
        os.remove(candidate)
    if DEBUG_MODE: print(f"Exited cleanup_temp_files function.")
def generate(audio_path, use_v2_fast):
    """Run the full stereo transcription pipeline on one audio file.

    Steps:
        1. Choose the device settings.
        2. Load the model(s).
        3. Split the stereo input into two mono temp files.
        4. Load them as waveforms.
        5. Transcribe each channel.
        6. Merge the results into one speaker-labelled transcript.

    The channel temp files are removed even if any step after model loading raises.

    Args:
        audio_path: Path to a stereo ``.wav`` or ``.mp3`` file.
        use_v2_fast: Use the faster_whisper model instead of whisper_timestamped.

    Returns:
        str: The merged transcript produced by ``post_process_transcripts``.

    NOTE(review): the temp paths come from fixed settings values, so concurrent calls
    would presumably share (and overwrite) the same files; confirm calls are serialized.
    """
    if DEBUG_MODE: print(f"Entering generate function...")
    start = time.time()
    device, compute_type = get_settings()
    model, fake_model = load_model(use_v2_fast, device, compute_type)
    try:
        split_input_stereo_channels(audio_path)
        left_waveform, right_waveform = process_waveforms()
        left_result, right_result = transcribe_channels(
            left_waveform, right_waveform, model, use_v2_fast, fake_model
        )
        output = post_process_transcripts(left_result, right_result, use_v2_fast)
    finally:
        # Always delete the channel temp files, including when a step above fails.
        cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    elapsed_secs = time.time() - start
    if DEBUG_MODE: print(f"elapsed_secs: {elapsed_secs}")
    if DEBUG_MODE: print(f"Exited generate function.")
    return output