from matplotlib import pyplot as plt
from accelerate import Accelerator
from zero import zero
import gradio as gr
from typing import Tuple
import os
from os import path
from utils import plot_spec
import librosa
from hashlib import md5
from demucs.separate import main as demucs
from pyannote.audio import Pipeline
from json import dumps, loads
import shutil
import zipfile

accelerator = Accelerator()
device = accelerator.device
print(f"Running on {device}")

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.environ["HF_TOKEN"]
)
pipeline.to(device)
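
# Note: "pyannote/speaker-diarization-3.1" is a gated model, so loading it
# assumes a Hugging Face token with accepted terms is exported as HF_TOKEN;
# `device` comes from accelerate and falls back to CPU when no GPU is present.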

# Collect the ids of task directories left over from previous runs.
tasks = []
os.makedirs("task", exist_ok=True)
for task in os.listdir("task"):
    if path.isdir(path.join("task", task)):
        tasks.append(task)
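
# Note: `tasks` is only collected here and not referenced elsewhere in this
# file; it appears to be kept for inspection or a future task-listing UI.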


def gen_task_id(location: str) -> str:
    # Hash the file bytes so identical uploads map to the same task directory;
    # md5 serves as a cache key here, not as a security measure.
    with open(location, "rb") as f:
        return md5(f.read()).hexdigest()


def extract_audio(video: str) -> Tuple[str, str]:
    task_id = gen_task_id(video)
    os.makedirs(path.join("task", task_id), exist_ok=True)

    # Keep a copy of the upload inside the task directory.
    videodest = path.join("task", task_id, "video.mp4")
    if not path.exists(videodest):
        shutil.copy(video, videodest)

    # Drop the video stream and resample the audio track to 48 kHz.
    wav48k = path.join("task", task_id, "extracted_48k.wav")
    if not path.exists(wav48k):
        os.system(f"ffmpeg -i {videodest} -vn -ar 48000 {wav48k}")

    return (task_id, wav48k)
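
# Note: every ffmpeg call in this file goes through os.system with unquoted
# paths. That works here because the paths are built from the hex task id,
# but subprocess.run([...]) would be the more robust choice in general.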


def extract_audio_post(task_id: str) -> str:
    wav48k = path.join("task", task_id, "extracted_48k.wav")
    if not path.exists(wav48k):
        raise gr.Error("Audio file not found")

    spec = path.join("task", task_id, "extracted_48k.png")
    if not path.exists(spec):
        # Load at 16 kHz to keep the spectrogram cheap to render.
        y, sr = librosa.load(wav48k, sr=16000)
        fig = plot_spec(y, sr)
        fig.savefig(spec)
        plt.close(fig)

    return spec
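
# plot_spec is a project-local helper from utils; this code only assumes it
# returns a matplotlib Figure for the given samples and sample rate.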


@zero(duration=60 * 2)
def extract_vocals(task_id: str) -> str:
    audio = path.join("task", task_id, "extracted_48k.wav")
    if not path.exists(audio):
        raise gr.Error("Audio file not found")

    # Demucs writes its stems to <out>/<model>/<input-stem>/<stem>.wav.
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    if not path.exists(vocals):
        demucs(
            [
                "-d",
                str(device),
                "-n",
                "htdemucs",
                "--two-stems",
                "vocals",
                "-o",
                path.join("task", task_id),
                audio,
            ]
        )

    return vocals
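
# --two-stems vocals asks htdemucs for just vocals.wav plus no_vocals.wav (the
# remaining stems mixed together), which is all the diarization step needs.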


def extract_vocals_post(task_id: str) -> str:
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    if not path.exists(vocals):
        raise gr.Error("Vocals file not found")

    spec = path.join("task", task_id, "vocals.png")
    if not path.exists(spec):
        y, sr = librosa.load(vocals, sr=16000)
        fig = plot_spec(y, sr)
        fig.savefig(spec)
        plt.close(fig)

    return spec


@zero(duration=60 * 2)
def diarize_audio(task_id: str):
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    if not path.exists(vocals):
        raise gr.Error("Vocals file not found")

    # Run the pyannote pipeline once and cache the raw turns as JSON.
    diarization_json = path.join("task", task_id, "diarization.json")
    if not path.exists(diarization_json):
        result = pipeline(vocals)
        diarization = []
        for turn, _, speaker in result.itertracks(yield_label=True):
            diarization.append(
                {
                    "speaker": speaker,
                    "start": turn.start,
                    "end": turn.end,
                    "duration": turn.duration,
                }
            )
        with open(diarization_json, "w") as f:
            f.write(dumps(diarization))
    with open(diarization_json, "r") as f:
        diarization = loads(f.read())

    # Keep only turns of at least 2 s, grouped by speaker, then drop speakers
    # with less than 60 s of total speech.
    filtered_json = path.join("task", task_id, "filtered_diarization.json")
    if not path.exists(filtered_json):
        filtered_segments = {}
        for turn in diarization:
            speaker = turn["speaker"]
            if turn["duration"] >= 2.0:
                if speaker not in filtered_segments:
                    filtered_segments[speaker] = []
                filtered_segments[speaker].append(turn)

        filtered_segments = {
            speaker: segments
            for speaker, segments in filtered_segments.items()
            if sum(segment["duration"] for segment in segments) >= 60
        }

        with open(filtered_json, "w") as f:
            f.write(dumps(filtered_segments))
    with open(filtered_json, "r") as f:
        filtered_segments = loads(f.read())

    return filtered_segments
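
# The cached filtered_diarization.json maps each speaker label to its kept
# turns; illustrative shape:
#   {"SPEAKER_00": [{"speaker": "SPEAKER_00", "start": 12.3, "end": 18.9,
#                    "duration": 6.6}, ...]}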


def generate_clips(task_id: str, speaker: str) -> Tuple[str, str, str]:
    video = path.join("task", task_id, "video.mp4")
    if not path.exists(video):
        raise gr.Error("Video file not found")

    filtered_json = path.join("task", task_id, "filtered_diarization.json")
    if not path.exists(filtered_json):
        raise gr.Error("Diarization not found")

    with open(filtered_json, "r") as f:
        filtered_segments = loads(f.read())

    if speaker not in filtered_segments:
        raise gr.Error("Speaker not found")

    # Build one filter_complex that trims every segment of this speaker and
    # concatenates them into a single highlight clip. The -map labels are
    # quoted so the shell cannot glob-expand the brackets.
    mp4 = path.join("task", task_id, f"{speaker}.mp4")
    if not path.exists(mp4):
        cmd = f'ffmpeg -i {video} -filter_complex "'
        for i, segment in enumerate(filtered_segments[speaker]):
            start = segment["start"]
            end = segment["end"]
            cmd += f"[0:v]trim=start={start}:end={end},setpts=PTS-STARTPTS[v{i}];"
            cmd += f"[0:a]atrim=start={start}:end={end},asetpts=PTS-STARTPTS[a{i}];"
        for i in range(len(filtered_segments[speaker])):
            cmd += f"[v{i}][a{i}]"
        cmd += f'concat=n={len(filtered_segments[speaker])}:v=1:a=1[outv][outa]" -map "[outv]" -map "[outa]" -y {mp4}'
        os.system(cmd)

    # Cut each kept turn out of the original audio track as its own wav file.
    segments = path.join("task", task_id, f"{speaker}")
    if not path.exists(segments):
        os.makedirs(segments)
        for i, segment in enumerate(filtered_segments[speaker]):
            start = segment["start"]
            end = segment["end"]
            name = path.join(segments, f"{i}_{start:.2f}_{end:.2f}.wav")
            cmd = f"ffmpeg -i {video} -ss {start} -to {end} -f wav {name}"
            os.system(cmd)

    segments_zip = path.join("task", task_id, f"{speaker}.zip")
    if not path.exists(segments_zip):
        with zipfile.ZipFile(segments_zip, "w") as zipf:
            files = [f for f in os.listdir(segments) if f.endswith(".wav")]
            for file in files:
                zipf.write(path.join(segments, file), file)

    # Repeat the cutting on the demucs vocals so both raw and cleaned
    # segments are downloadable.
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    vocal_segments = path.join("task", task_id, f"{speaker}_vocals")
    if not path.exists(vocal_segments):
        os.makedirs(vocal_segments)
        for i, segment in enumerate(filtered_segments[speaker]):
            start = segment["start"]
            end = segment["end"]
            name = path.join(vocal_segments, f"{i}_{start:.2f}_{end:.2f}.wav")
            cmd = f"ffmpeg -i {vocals} -ss {start} -to {end} -f wav {name}"
            os.system(cmd)

    vocal_segments_zip = path.join("task", task_id, f"{speaker}_vocals.zip")
    if not path.exists(vocal_segments_zip):
        with zipfile.ZipFile(vocal_segments_zip, "w") as zipf:
            files = [f for f in os.listdir(vocal_segments) if f.endswith(".wav")]
            for file in files:
                zipf.write(path.join(vocal_segments, file), file)

    return mp4, segments_zip, vocal_segments_zip
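
# For two kept segments the assembled command looks roughly like this
# (illustrative timestamps, wrapped here; the real string is one line):
#   ffmpeg -i video.mp4 -filter_complex "
#     [0:v]trim=start=12.3:end=18.9,setpts=PTS-STARTPTS[v0];
#     [0:a]atrim=start=12.3:end=18.9,asetpts=PTS-STARTPTS[a0];
#     [0:v]trim=start=40.0:end=45.5,setpts=PTS-STARTPTS[v1];
#     [0:a]atrim=start=40.0:end=45.5,asetpts=PTS-STARTPTS[a1];
#     [v0][a0][v1][a1]concat=n=2:v=1:a=1[outv][outa]"
#   -map "[outv]" -map "[outa]" -y SPEAKER_00.mp4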


with gr.Blocks() as app:
    gr.Markdown("# Video Speaker Diarization")

    gr.Markdown(
        """
        First, upload a video file, and we will inspect the audio track of the video.
        """
    )
    original_video = gr.Video(label="Upload a video", show_download_button=True)
    preprocess_btn = gr.Button(value="Pre Process", variant="primary")
    preprocess_btn_label = gr.Markdown("Press the button!")
    task_id = gr.Textbox(label="Task ID", visible=False)

    with gr.Column(visible=False) as preprocess_output:
        gr.Markdown(
            """
            Now you can see the spectrogram of the extracted audio.

            Next, let's remove the background music from the audio.
            """
        )

        with gr.Row():
            extracted_audio = gr.Audio(label="Extracted Audio", type="filepath")
            extracted_audio_spec = gr.Image(label="Extracted Audio Spectrogram")

        extract_vocals_btn = gr.Button(
            value="Remove Background Music", variant="primary"
        )
        extract_vocals_btn_label = gr.Markdown("Press the button!")

    with gr.Column(visible=False) as extract_vocals_output:
        with gr.Row():
            vocals = gr.Audio(label="Vocals", type="filepath")
            vocals_spec = gr.Image(label="Vocals Spectrogram")

        diarize_btn = gr.Button(value="Diarize", variant="primary")
        diarize_btn_label = gr.Markdown("Press the button!")

    with gr.Column(visible=False) as diarize_output:
        gr.Markdown(
            """
            Now you can select a speaker from the dropdown below to generate
            that speaker's clips.
            """
        )
        with gr.Row():
            speaker_select = gr.Dropdown(label="Speaker", choices=[])
            diarization_result = gr.Markdown("", height=400)

        generate_clips_btn = gr.Button(value="Generate Clips", variant="primary")
        generate_clips_btn_label = gr.Markdown("Press the button!")

    with gr.Column(visible=False) as generate_clips_output:
        speaker_clip = gr.Video(label="Speaker Clip")
        speaker_clip_zip = gr.File(label="Download Audio Segments")
        speaker_clip_vocal_zip = gr.File(label="Download Vocal Segments")

    def preprocess(video: str):
        task_id_val, extracted_audio_val = extract_audio(video)
        return {
            preprocess_output: gr.Column(visible=True),
            task_id: task_id_val,
            extracted_audio: extracted_audio_val,
            preprocess_btn_label: gr.Markdown("", visible=False),
        }

    preprocess_btn.click(
        fn=preprocess,
        inputs=[original_video],
        outputs=[
            preprocess_output,
            task_id,
            extracted_audio,
            preprocess_btn_label,
        ],
        api_name="preprocess",
    ).success(
        fn=extract_audio_post,
        inputs=[task_id],
        outputs=[extracted_audio_spec],
        api_name="preprocess-post",
    )
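
    # .click(...).success(...) chains the spectrogram render so it runs only
    # after the preprocessing handler returned without raising an Error.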

    def extract_vocals_fn(task_id: str):
        vocals_val = extract_vocals(task_id)
        return {
            extract_vocals_output: gr.Column(visible=True),
            vocals: vocals_val,
            extract_vocals_btn_label: gr.Markdown("", visible=False),
        }

    extract_vocals_btn.click(
        fn=extract_vocals_fn,
        inputs=[task_id],
        outputs=[extract_vocals_output, vocals, extract_vocals_btn_label],
        api_name="extract-vocals",
    ).success(
        fn=extract_vocals_post,
        inputs=[task_id],
        outputs=[vocals_spec],
        api_name="extract-vocals-post",
    )

    def diarize_fn(task_id: str):
        filtered_segments = diarize_audio(task_id)
        choices = []
        for speaker in filtered_segments:
            total = sum(segment["duration"] for segment in filtered_segments[speaker])
            choices.append((f"{speaker} ({total:.2f}s)", speaker))

        # Build a markdown summary of every kept turn per speaker.
        info = ""
        for speaker, segments in filtered_segments.items():
            total = sum(segment["duration"] for segment in segments)
            info += f"### Speaker {speaker}: ({total:.2f}s)\n"
            for segment in segments:
                start = segment["start"]
                end = segment["end"]
                info += f"- {start:.2f} - {end:.2f} ({segment['duration']:.2f}s)\n"
        return {
            diarize_output: gr.Column(visible=True),
            speaker_select: gr.Dropdown(label="Speaker", choices=choices),
            diarization_result: gr.Markdown(info),
            diarize_btn_label: gr.Markdown("", visible=False),
        }

    diarize_btn.click(
        fn=diarize_fn,
        inputs=[task_id],
        outputs=[diarize_output, speaker_select, diarization_result, diarize_btn_label],
        api_name="diarize",
    )

    def generate_clips_fn(task_id: str, speaker: str):
        speaker_clip_val, zip_val, vocal_zip_val = generate_clips(task_id, speaker)
        return {
            generate_clips_output: gr.Column(visible=True),
            speaker_clip: speaker_clip_val,
            speaker_clip_zip: zip_val,
            speaker_clip_vocal_zip: vocal_zip_val,
            generate_clips_btn_label: gr.Markdown("", visible=False),
        }

    generate_clips_btn.click(
        fn=generate_clips_fn,
        inputs=[task_id, speaker_select],
        outputs=[
            generate_clips_output,
            speaker_clip,
            speaker_clip_zip,
            speaker_clip_vocal_zip,
            generate_clips_btn_label,
        ],
        api_name="generate_clips",
    )
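
    # Note: generate_clips is not wrapped in @zero; it only shells out to
    # ffmpeg and builds zip files, so it needs no GPU allocation.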


app.launch()