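"""
sv.py: speaker-diarized, emotion/event-aware transcription.

Runs pyannote speaker diarization over an audio file, transcribes each speaker
turn with the SenseVoiceSmall model via FunASR, maps SenseVoice emotion/event
tags to emojis, and writes a time-stamped, per-speaker transcript.
"""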
import datetime
import math
import os
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from pyannote.audio import Audio, Pipeline
from pyannote.core import Segment
# Load models
model = AutoModel(
model="FunAudioLLM/SenseVoiceSmall",
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_kwargs={"max_single_segment_time": 30000},
hub="hf",
device="cuda" if torch.cuda.is_available() else "cpu",
)
pyannote_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
)
if torch.cuda.is_available():
pyannote_pipeline.to(torch.device("cuda"))
# Emoji dictionaries and formatting functions
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}
lang_dict = {
"<|zh|>": "<|lang|>",
"<|en|>": "<|lang|>",
"<|yue|>": "<|lang|>",
"<|ja|>": "<|lang|>",
"<|ko|>": "<|lang|>",
"<|nospeech|>": "<|lang|>",
}
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def format_text_with_emojis(s):
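    """Strip SenseVoice special tags from `s`, prefix the emojis of any detected
    audio events, and append the emoji of the dominant (most frequent) emotion tag."""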
sptk_dict = {sptk: s.count(sptk) for sptk in emoji_dict}
for sptk in emoji_dict:
s = s.replace(sptk, "")
emo = "<|NEUTRAL|>"
for e in emo_dict:
if sptk_dict.get(e, 0) > sptk_dict.get(emo, 0):
emo = e
s = (
"".join(event_dict[e] for e in event_dict if sptk_dict.get(e, 0) > 0)
+ s
+ emo_dict[emo]
)
for emoji in emo_set.union(event_set):
s = s.replace(f" {emoji}", emoji).replace(f"{emoji} ", emoji)
return s.strip()
def clean_and_emoji_annotate_speech(text):
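    """Normalize a raw SenseVoice transcription: collapse language tags, format
    each language segment with format_text_with_emojis, drop leading event emojis,
    and keep an emotion emoji only when it differs from the previous segment's."""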
def get_emoji(s, emoji_set):
return next((char for char in s if char in emoji_set), None)
# Replace special tags
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
for lang, replacement in lang_dict.items():
text = text.replace(lang, replacement)
# Process each language segment
segments = [
format_text_with_emojis(segment.strip()) for segment in text.split("<|lang|>")
]
formatted_segments = []
prev_event = prev_emotion = None
for segment in segments:
if not segment:
continue
current_event = get_emoji(segment, event_set)
current_emotion = get_emoji(
segment, emo_set
) # Check for emotion emoji anywhere in the segment
if current_event is not None:
segment = segment[1:] if segment.startswith(current_event) else segment
# Preserve emotion emoji if it's different from the previous one
if current_emotion is not None and current_emotion != prev_emotion:
segment = segment.replace(current_emotion, "") + current_emotion
formatted_segments.append(segment.strip())
prev_event, prev_emotion = current_event, current_emotion
result = " ".join(formatted_segments).replace("The.", "").strip()
return result
def time_to_seconds(time_str):
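    """Convert an "HH:MM:SS(.fff)" timestamp string to seconds
    (helper; not currently called elsewhere in this script)."""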
h, m, s = time_str.split(":")
return round(int(h) * 3600 + int(m) * 60 + float(s), 9)
def parse_time(time_str):
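    """Parse "SS(.fff)s", "MM:SS", or "HH:MM:SS(.fff)" time strings into seconds."""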
# Remove 's' if present at the end of the string
time_str = time_str.rstrip("s")
# Split the time string into hours, minutes, and seconds
parts = time_str.split(":")
if len(parts) == 3:
h, m, s = parts
elif len(parts) == 2:
h = "0"
m, s = parts
else:
h = m = "0"
s = parts[0]
return int(h) * 3600 + int(m) * 60 + float(s)
def format_time(seconds, use_short_format=True):
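    """Format seconds (or a datetime.timedelta) for display. The short format
    keeps only the seconds component, so it is lossy for durations over a minute."""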
if isinstance(seconds, datetime.timedelta):
seconds = seconds.total_seconds()
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(int(minutes), 60)
if use_short_format or (hours == 0 and minutes == 0):
return f"{seconds:05.3f}s"
elif hours == 0:
return f"{minutes:02d}:{seconds:06.3f}"
else:
return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
def format_time_with_leading_zeros(seconds):
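    """Format seconds as a zero-padded seconds string, e.g. 5.5 -> "05.500s"."""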
formatted = f"{seconds:06.3f}s"
print(f"Debug: Input seconds: {seconds}, Formatted output: {formatted}")
return formatted
def generate_diarization(audio_path):
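    """Run pyannote speaker diarization on the first existing file among a few
    candidate paths (falling back to `audio_path`), write a human-readable log
    to mtr_dn.txt, and return a list of (start, end, duration, speaker) tuples
    with times in seconds."""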
# Get the Hugging Face token from the environment variable
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
raise ValueError(
"HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
)
# Initialize the audio processor
audio = Audio(sample_rate=16000, mono=True)
    # Reuse the diarization pipeline loaded at module import time (already moved
    # to GPU when available) instead of initializing a second copy here
    pipeline = pyannote_pipeline
# Set the correct path for the audio file
script_dir = os.path.dirname(os.path.abspath(__file__))
possible_paths = [
os.path.join(script_dir, "example", "mtr.mp3"),
os.path.join(script_dir, "..", "example", "mtr.mp3"),
os.path.join(script_dir, "mtr.mp3"),
"mtr.mp3",
audio_path, # Add the provided audio_path to the list of possible paths
]
file_path = None
for path in possible_paths:
if os.path.exists(path):
file_path = path
break
if file_path is None:
print("Debugging information:")
print(f"Current working directory: {os.getcwd()}")
print(f"Script directory: {script_dir}")
print("Attempted paths:")
for path in possible_paths:
print(f" {path}")
raise FileNotFoundError(
"Could not find the audio file. Please ensure it's in the correct location."
)
print(f"Using audio file: {file_path}")
# Process the audio file
waveform, sample_rate = audio(file_path)
# Create a dictionary with the audio information
file = {"waveform": waveform, "sample_rate": sample_rate, "uri": "mtr"}
# Run the diarization
output = pipeline(file)
# Save results in human-readable format
diarization_segments = []
txt_file = "mtr_dn.txt"
with open(txt_file, "w") as f:
for turn, _, speaker in output.itertracks(yield_label=True):
start_time = format_time(turn.start)
end_time = format_time(turn.end)
duration = format_time(turn.end - turn.start)
line = f"{start_time} - {end_time} ({duration}): {speaker}\n"
f.write(line)
print(line.strip())
            # Store raw seconds rather than round-tripping through the formatted
            # strings, which would drop the minutes/hours for longer recordings
            diarization_segments.append(
                (turn.start, turn.end, turn.end - turn.start, speaker)
            )
print(f"\nHuman-readable diarization results saved to {txt_file}")
return diarization_segments
def process_audio(audio_path, language="yue", fs=16000):
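    """Diarize `audio_path`, transcribe each speaker turn with SenseVoice in the
    requested language, and return the formatted per-speaker transcript."""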
# Generate diarization segments
diarization_segments = generate_diarization(audio_path)
# Load and preprocess audio
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != fs:
resampler = torchaudio.transforms.Resample(sample_rate, fs)
waveform = resampler(waveform)
input_wav = waveform.mean(0).numpy()
# Determine if the audio is less than one minute
total_duration = sum(duration for _, _, duration, _ in diarization_segments)
use_short_format = total_duration < 60
# Process the audio in chunks based on diarization segments
results = []
for start_time, end_time, duration, speaker in diarization_segments:
start_seconds = start_time
end_seconds = end_time
# Convert time to sample indices
start_sample = int(start_seconds * fs)
end_sample = int(end_seconds * fs)
chunk = input_wav[start_sample:end_sample]
try:
text = model.generate(
input=chunk,
cache={},
language=language,
use_itn=True,
batch_size_s=500,
merge_vad=True,
)
text = text[0]["text"]
# Print the text before clean_and_emoji_annotate_speech
print(f"Text before clean_and_emoji_annotate_speech: {text}")
text = clean_and_emoji_annotate_speech(text)
# Handle empty transcriptions
if not text.strip():
text = "[inaudible]"
results.append((speaker, start_time, end_time, duration, text))
except AssertionError as e:
if "choose a window size" in str(e):
print(
f"Warning: Audio segment too short to process. Skipping. Error: {e}"
)
results.append((speaker, start_time, end_time, duration, "[too short]"))
else:
raise
# Format the results
formatted_text = ""
for speaker, start, end, duration, text in results:
start_str = format_time_with_leading_zeros(start)
end_str = format_time_with_leading_zeros(end)
duration_str = format_time_with_leading_zeros(duration)
        # Derive a 1-based speaker number from the pyannote label ("SPEAKER_00" -> 1),
        # so recordings with more than two speakers are labelled correctly
        speaker_num = int(speaker.split("_")[-1]) + 1
line = f"{start_str} - {end_str} ({duration_str}) Speaker {speaker_num}: {text}"
formatted_text += line + "\n"
print(f"Debug: Formatted line: {line}")
print("Debug: Full formatted text:")
print(formatted_text)
return formatted_text.strip()
if __name__ == "__main__":
audio_path = "example/mtr.mp3" # Replace with your audio file path
language = "yue" # Set language to Cantonese
result = process_audio(audio_path, language)
# Save the result to mtr.txt
output_path = "mtr.txt"
with open(output_path, "w", encoding="utf-8") as f:
f.write(result)
print(f"Diarization and transcription result has been saved to {output_path}")