#!/usr/bin/python3
# -*- coding: utf-8 -*-
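"""
main.py of the cc_vad gradio demo: voice activity detection (VAD).

The app downloads pretrained VAD model archives (fsmn and silero variants) from the
Hugging Face Hub repo given by --models_repo_id, lets the user upload or record audio,
and plots the frame-level speech probabilities, the lsnr curve and the post-processed
vad decision, together with a JSON timing message.

Example launch (the values below are the argparse defaults):
    python main.py --models_repo_id qgyd2021/cc_vad --server_port 7860
"""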
import argparse
from functools import lru_cache
import json
import logging
from pathlib import Path
import platform
import shutil
import tempfile
import time
from typing import Dict, Tuple
import zipfile
import gradio as gr
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
import numpy as np
import log
from project_settings import environment, project_path, log_directory, time_zone_info
from toolbox.os.command import Command
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.utils.visualization import process_speech_probs
from toolbox.vad.utils import PostProcess
log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
logger = logging.getLogger("main")
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        # default=(project_path / "data").as_posix(),
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/cc_vad",
        type=str
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )
    args = parser.parse_args()
    return args
def shell(cmd: str):
    return Command.popen(cmd)
def get_infer_cls_by_model_name(model_name: str):
    if "fsmn" in model_name:
        infer_cls = InferenceFSMNVadOnnx
    elif "silero" in model_name:
        infer_cls = InferenceSileroVad
    else:
        raise AssertionError(f"unsupported model name: {model_name}")
    return infer_cls
# model name -> {"infer_cls": ..., "kwargs": ...}; populated in main() after the models are downloaded.
vad_engines: Dict[str, dict] = None
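# The engine is instantiated lazily and cached; with maxsize=1, switching to a different engine in
# the UI evicts the previously loaded model, while repeated requests with the same engine reuse it.
# lru_cache requires hashable arguments, which holds here since the kwargs values are plain strings.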
@lru_cache(maxsize=1)
def load_vad_model(infer_cls, **kwargs):
    infer_engine = infer_cls(**kwargs)
    return infer_engine
def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
    duration = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    # blue: waveform; gray: frame-level values (probability, lsnr or vad decision) expanded to sample resolution.
    plt.plot(duration, signal, color='b')
    plt.plot(duration, speech_probs, color='gray')
    plt.title(title)
    # delete=False: the file must outlive this function so gradio can read the returned path.
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    plt.savefig(temp_file.name, bbox_inches="tight")
    plt.close()
    return temp_file.name
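# Gradio click handler: takes either the uploaded file or the microphone recording, runs the
# selected VAD engine, post-processes the frame-level probabilities into speech segments, and
# returns three plots (vad decision, probability, lsnr) plus a JSON message with timing info.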
def when_click_vad_button(audio_file_t=None, audio_microphone_t=None,
                          start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
                          min_silence_length: int = 2,
                          max_speech_length: int = 10000, min_speech_length: int = 10,
                          engine: str = None,
                          ):
    if audio_file_t is None and audio_microphone_t is None:
        raise gr.Error("audio file and microphone are both empty.")
    if audio_file_t is not None and audio_microphone_t is not None:
        gr.Warning("both audio file and microphone are provided, audio file takes priority.")
    audio_t: Tuple = audio_file_t or audio_microphone_t

    sample_rate, signal = audio_t
    audio_duration = signal.shape[-1] / sample_rate
    # gradio delivers int16 samples; scale to float32 in [-1, 1].
    audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = vad_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid vad engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)

        begin = time.time()
        vad_info = infer_engine.infer(audio)
        time_cost = time.time() - begin

        # fpr: processing time per second of audio (real-time factor).
        fpr = time_cost / audio_duration
        info = {
            "time_cost": round(time_cost, 4),
            "audio_duration": round(audio_duration, 4),
            "fpr": round(fpr, 4)
        }
        message = json.dumps(info, ensure_ascii=False, indent=4)

        probs = vad_info["probs"]
        lsnr = vad_info["lsnr"]
        # lsnr = lsnr / np.max(np.abs(lsnr))
        # scale lsnr down so it can be drawn on the same axis as the waveform.
        lsnr = lsnr / 30

        frame_step = infer_engine.config.hop_size
        probs_ = process_speech_probs(audio, probs, frame_step)
        probs_image = generate_image(audio, probs_)
        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
        lsnr_image = generate_image(audio, lsnr_)

        # post process
        vad_post_process = PostProcess(
            start_ring_rate=start_ring_rate,
            end_ring_rate=end_ring_rate,
            min_silence_length=min_silence_length,
            max_speech_length=max_speech_length,
            min_speech_length=min_speech_length
        )
        vad = vad_post_process.post_process(probs)
        vad_ = process_speech_probs(audio, vad, frame_step)
        vad_image = generate_image(audio, vad_)
    except Exception as e:
        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")

    return vad_image, probs_image, lsnr_image, message
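# main(): parse args, download the model archives, build the engine registry, unpack the examples,
# then build and launch the gradio ui (a vad tab plus a shell tab for debugging).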
def main():
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download models
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )
    # engines
    global vad_engines
    vad_engines = {
        filename.stem: {
            "infer_cls": get_infer_cls_by_model_name(filename.stem),
            "kwargs": {
                "pretrained_model_path_or_zip_file": filename.as_posix()
            }
        }
        for filename in trained_model_dir.glob("*.zip")
        # zip archives in the repo that should not be offered as engine choices.
        if filename.name not in (
            "cnn-vad-by-webrtcvad-nx-dns3.zip",
            "fsmn-vad-by-webrtcvad-nx-dns3.zip",
            "examples.zip",
            "sound-2-ch32.zip",
            "sound-3-ch32.zip",
            "sound-4-ch32.zip",
            "sound-8-ch32.zip",
        )
    }
    # choices
    vad_engine_choices = list(vad_engines.keys())

    # examples
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            out_root = examples_dir
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

    # examples
    examples = list()
    for filename in examples_dir.glob("**/*.wav"):
        examples.append([
            filename.as_posix(),
            None,
            vad_engine_choices[0],
        ])
    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            with gr.TabItem("vad"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                vad_audio_file = gr.Audio(label="audio")
                            with gr.TabItem("microphone"):
                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
                        with gr.Row():
                            vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
                            vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
                            vad_min_silence_length = gr.Number(value=30, label="min_silence_length")
                        with gr.Row():
                            vad_max_speech_length = gr.Number(value=100000, label="max_speech_length")
                            vad_min_speech_length = gr.Number(value=15, label="min_speech_length")
                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
                        vad_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        vad_vad_image = gr.Image(label="vad")
                        vad_prob_image = gr.Image(label="prob")
                        vad_lsnr_image = gr.Image(label="lsnr")
                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")

                vad_button.click(
                    when_click_vad_button,
                    inputs=[
                        vad_audio_file, vad_audio_microphone,
                        vad_start_ring_rate, vad_end_ring_rate,
                        vad_min_silence_length,
                        vad_max_speech_length, vad_min_speech_length,
                        vad_engine,
                    ],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                )
                gr.Examples(
                    examples=examples,
                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                    fn=when_click_vad_button,
                    # cache_examples=True,
                    # cache_mode="lazy",
                )
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )
    # http://127.0.0.1:7866/
    # http://10.75.27.247:7866/
    blocks.queue().launch(
        # share=True,
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return
if __name__ == "__main__":
    main()