from pydub import AudioSegment
import os
import torchaudio
import torch
import re
import whisper_timestamped as whisper_ts
from faster_whisper import WhisperModel
from settings import DEBUG_MODE, MODEL_PATH_V2_FAST, MODEL_PATH_V2, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, FAKE_AUDIO_PATH, RESAMPLING_FREQ
import time
def get_settings():
    if DEBUG_MODE: print(f"Entering get_settings function...")
    # HACK hardcoding this to try
    is_cuda_available = True  # torch.cuda.is_available()
    if is_cuda_available:
        device = "cuda"
        compute_type = "float16"
    else:
        device = "cpu"
        compute_type = "int8"
    if DEBUG_MODE: print(f"is_cuda_available: {is_cuda_available}")
    if DEBUG_MODE: print(f"device: {device}")
    if DEBUG_MODE: print(f"compute_type: {compute_type}")
    if DEBUG_MODE: print(f"Exited get_settings function.")
    return device, compute_type
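# Illustrative note: with CUDA hardcoded on above, get_settings() returns ("cuda", "float16");
# if the hack is removed and no GPU is available, it would return ("cpu", "int8").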
def load_model(use_v2_fast, device, compute_type):
    if DEBUG_MODE: print(f"Entering load_model function...")
    if DEBUG_MODE: print(f"use_v2_fast: {use_v2_fast}")
    if use_v2_fast:
        if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2_FAST} using {device} with {compute_type}...")
        model = WhisperModel(
            MODEL_PATH_V2_FAST,
            device=device,
            compute_type=compute_type,
        )
    else:
        if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2} using {device} with {compute_type}...")
        # TODO add compute_type to load model
        model = whisper_ts.load_model(
            MODEL_PATH_V2,
            device=device,
        )
    # HACK we need to do this for strange reasons.
    # If we don't do this, we get:
    # Could not load library libcudnn_ops_infer.so.8. Error: libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory
    fake_model = whisper_ts.load_model(MODEL_PATH_V2, device=device)
    if DEBUG_MODE: print(f"Exiting load_model function...")
    return model, fake_model
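# Note: the second return value ("fake_model") is always a whisper_timestamped model; it is only
# used in transcribe_channels to warm up cuDNN and work around the libcudnn loading error above.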
def split_input_stereo_channels(audio_path):
    if DEBUG_MODE: print(f"Entering split_input_stereo_channels function...")
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format for: {audio_path}")
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} has {len(channels)} channels (instead of 2).")
    channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav")  # Right
    channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav")  # Left
    if DEBUG_MODE: print(f"Exited split_input_stereo_channels function.")
def format_audio(audio_path):
    if DEBUG_MODE: print(f"Entering format_audio function...")
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze()
    if DEBUG_MODE: print(f"Exited format_audio function.")
    return input_audio, RESAMPLING_FREQ
def process_waveforms():
    if DEBUG_MODE: print(f"Entering process_waveforms function...")
    left_waveform, _ = format_audio(LEFT_CHANNEL_TEMP_PATH)
    right_waveform, _ = format_audio(RIGHT_CHANNEL_TEMP_PATH)
    # TODO should this be equal to compute_type?
    left_waveform = left_waveform.numpy().astype("float16")
    right_waveform = right_waveform.numpy().astype("float16")
    if DEBUG_MODE: print(f"Exited process_waveforms function.")
    return left_waveform, right_waveform
def transcribe_audio_no_fast_model(model, audio_path):
    if DEBUG_MODE: print(f"Entering transcribe_audio_no_fast_model function...")
    # audio_path may be a file path or a waveform array; callers in this module pass waveforms.
    result = whisper_ts.transcribe(
        model,
        audio_path,
        beam_size=5,
        best_of=5,
        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        vad=False,
        detect_disfluencies=True,
    )
    words = []
    for segment in result.get('segments', []):
        for word in segment.get('words', []):
            word_text = word.get('word', '').strip()
            words.append({
                'word': word_text,
                'start': word.get('start', 0),
                'end': word.get('end', 0),
                'confidence': word.get('confidence', 0)
            })
    if DEBUG_MODE: print(f"Exited transcribe_audio_no_fast_model function.")
    return {
        'audio_path': audio_path,
        'text': result['text'].strip(),
        'segments': result.get('segments', []),
        'words': words,
        'duration': result.get('duration', 0),
        'success': True
    }
def transcribe_channels(left_waveform, right_waveform, model, use_v2_fast, fake_model):
    if DEBUG_MODE: print(f"Entering transcribe_channels function...")
    # HACK we need to do this for strange reasons.
    # If we don't do this, we get:
    # Could not load library libcudnn_ops_infer.so.8. Error: libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory
    fake_result = whisper_ts.transcribe(
        fake_model,
        FAKE_AUDIO_PATH,
        beam_size=1,
    )
    if DEBUG_MODE: print(f"Preparing to transcribe...")
    if use_v2_fast:
        left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
        right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
        left_result = list(left_result)
        right_result = list(right_result)
    else:
        left_result = transcribe_audio_no_fast_model(model, left_waveform)
        right_result = transcribe_audio_no_fast_model(model, right_waveform)
    if DEBUG_MODE: print(f"Exited transcribe_channels function.")
    return left_result, right_result
# TODO refactor and rename this function
def post_process_transcription(transcription, max_repeats=2):
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse a short substring repeated 3+ times inside a token down to a single occurrence.
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
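# Illustrative example (hypothetical input): with max_repeats=2,
# post_process_transcription("si si si si si") keeps at most two consecutive duplicates
# and returns "si si"; an in-token stutter like "lalalala" is reduced to "la".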
# TODO not used right now, decide to use it or not
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_transcription = ''
    current_speaker = None
    current_segment = []
    for i in range(1, len(segments) - 1, 2):
        speaker_tag = segments[i]
        text = segments[i + 1].strip()
        speaker = re.search(r'\d{2}', speaker_tag).group()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
            current_speaker = speaker
            current_segment = [text]
    if current_speaker is not None:
        merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
    return merged_transcription.strip()
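# Illustrative example (hypothetical input):
# post_merge_consecutive_segments_from_text("[SPEAKER_00] hi [SPEAKER_00] there [SPEAKER_01] hello")
# merges the consecutive SPEAKER_00 turns and returns "[SPEAKER_00] hi there\n[SPEAKER_01] hello".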
def get_segments(result, speaker_label, use_v2_fast):
    if DEBUG_MODE: print(f"Entering get_segments function...")
    if use_v2_fast:
        segments = result
        final_segments = [
            (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
            for seg in segments if seg.text
        ]
    else:
        segments = result.get("segments", [])
        final_segments = [
            (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
             post_process_transcription(seg.get("text", "").strip()))
            for seg in segments if seg.get("text")
        ]
    if DEBUG_MODE: print(f"Exited get_segments function.")
    return final_segments
def post_process_transcripts(left_result, right_result, use_v2_fast):
    if DEBUG_MODE: print(f"Entering post_process_transcripts function...")
    left_segs = get_segments(left_result, "Speaker 1", use_v2_fast)
    right_segs = get_segments(right_result, "Speaker 2", use_v2_fast)
    merged_transcript = sorted(
        left_segs + right_segs,
        key=lambda x: float(x[0]) if x[0] is not None else float("inf")
    )
    clean_output = ""
    for start, end, speaker, text in merged_transcript:
        clean_output += f"[{speaker}]: {text}\n"
    clean_output = clean_output.strip()
    if DEBUG_MODE: print(f"Exited post_process_transcripts function.")
    return clean_output
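# Illustrative output format: segments from both channels are interleaved by start time,
# one line per segment, e.g. "[Speaker 1]: hello\n[Speaker 2]: hi there".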
def cleanup_temp_files(*file_paths):
    if DEBUG_MODE: print(f"Entered cleanup_temp_files function...")
    if DEBUG_MODE: print(f"File paths to remove: {file_paths}")
    for path in file_paths:
        if path and os.path.exists(path):
            if DEBUG_MODE: print(f"Removing path: {path}")
            os.remove(path)
    if DEBUG_MODE: print(f"Exited cleanup_temp_files function.")
def generate(audio_path, use_v2_fast):
    if DEBUG_MODE: print(f"Entering generate function...")
    start = time.time()
    device, compute_type = get_settings()
    model, fake_model = load_model(use_v2_fast, device, compute_type)
    split_input_stereo_channels(audio_path)
    left_waveform, right_waveform = process_waveforms()
    left_result, right_result = transcribe_channels(left_waveform, right_waveform, model, use_v2_fast, fake_model)
    output = post_process_transcripts(left_result, right_result, use_v2_fast)
    cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    end = time.time()
    elapsed_secs = end - start
    if DEBUG_MODE: print(f"elapsed_secs: {elapsed_secs}")
    if DEBUG_MODE: print(f"Exited generate function.")
    return output
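# Minimal usage sketch (assumption: run as a script on a local stereo recording;
# "call.wav" and the use_v2_fast value below are hypothetical, not part of this module).
if __name__ == "__main__":
    transcript = generate("call.wav", use_v2_fast=True)
    print(transcript)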