#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from functools import lru_cache
import json
import logging
from pathlib import Path
import platform
import shutil
import tempfile
import time
from typing import Dict, Tuple
import zipfile

import gradio as gr
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
import numpy as np

import log
from project_settings import environment, project_path, log_directory, time_zone_info
from toolbox.os.command import Command
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.utils.visualization import process_speech_probs
from toolbox.vad.utils import PostProcess

log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)

logger = logging.getLogger("main")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        # default=(project_path / "data").as_posix(),
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/cc_vad",
        type=str
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )
    args = parser.parse_args()
    return args


def shell(cmd: str):
    return Command.popen(cmd)


def get_infer_cls_by_model_name(model_name: str):
    if "fsmn" in model_name:
        infer_cls = InferenceFSMNVadOnnx
    elif "silero" in model_name:
        infer_cls = InferenceSileroVad
    else:
        raise AssertionError(f"unsupported model name: {model_name}")
    return infer_cls


vad_engines: Dict[str, dict] = None


@lru_cache(maxsize=1)
def load_vad_model(infer_cls, **kwargs):
    # maxsize=1: only the most recently used engine stays in memory;
    # switching engines evicts the previous one and reloads.
    infer_engine = infer_cls(**kwargs)
    return infer_engine


def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
    duration = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(duration, signal, color='b')
    plt.plot(duration, speech_probs, color='gray')
    plt.title(title)

    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    plt.savefig(temp_file.name, bbox_inches="tight")
    plt.close()
    return temp_file.name


def when_click_vad_button(audio_file_t=None,
                          audio_microphone_t=None,
                          start_ring_rate: float = 0.5,
                          end_ring_rate: float = 0.3,
                          min_silence_length: int = 2,
                          max_speech_length: int = 10000,
                          min_speech_length: int = 10,
                          engine: str = None,
                          ):
    if audio_file_t is None and audio_microphone_t is None:
        raise gr.Error("audio file and microphone are both empty.")
    if audio_file_t is not None and audio_microphone_t is not None:
        gr.Warning("both audio file and microphone are provided; the audio file takes priority.")

    audio_t: Tuple = audio_file_t or audio_microphone_t

    sample_rate, signal = audio_t
    # duration in seconds at the actual sample rate
    audio_duration = signal.shape[-1] / sample_rate
    # convert int16 PCM to float32 in [-1, 1]
    audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = vad_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid vad engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)

        begin = time.time()
        vad_info = infer_engine.infer(audio)
        time_cost = time.time() - begin

        fpr = time_cost / audio_duration
        info = {
            "time_cost": round(time_cost, 4),
            "audio_duration": round(audio_duration, 4),
            "fpr": round(fpr, 4)
        }
        message = json.dumps(info, ensure_ascii=False, indent=4)

        probs = vad_info["probs"]
        lsnr = vad_info["lsnr"]
        # lsnr = lsnr / np.max(np.abs(lsnr))
        lsnr = lsnr / 30

        frame_step = infer_engine.config.hop_size

        probs_ = process_speech_probs(audio, probs, frame_step)
        probs_image = generate_image(audio, probs_)
        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
        lsnr_image = generate_image(audio, lsnr_)

        # post process
        vad_post_process = PostProcess(
            start_ring_rate=start_ring_rate,
            end_ring_rate=end_ring_rate,
            min_silence_length=min_silence_length,
            max_speech_length=max_speech_length,
            min_speech_length=min_speech_length
        )
        vad = vad_post_process.post_process(probs)
        vad_ = process_speech_probs(audio, vad, frame_step)
        vad_image = generate_image(audio, vad_)
    except Exception as e:
        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")

    return vad_image, probs_image, lsnr_image, message
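
# A minimal offline sketch of the same pipeline, kept for reference only (it is
# never called by the UI). It assumes `scipy` is installed, that the engine zips
# have already been downloaded into `trained_models`, and that `vad_engines` has
# been populated (see `main()` below); the default engine name "fsmn-vad" is a
# placeholder, use any key from `vad_engines`.
def run_vad_offline(wav_file: str, engine: str = "fsmn-vad"):
    from scipy.io import wavfile

    sample_rate, signal = wavfile.read(wav_file)
    # same int16 -> float32 normalization as `when_click_vad_button`
    audio = np.array(signal / (1 << 15), dtype=np.float32)

    param = vad_engines[engine]
    infer_engine = load_vad_model(infer_cls=param["infer_cls"], **param["kwargs"])
    vad_info = infer_engine.infer(audio)

    # same defaults as the UI sliders / number fields
    vad_post_process = PostProcess(
        start_ring_rate=0.5,
        end_ring_rate=0.3,
        min_silence_length=30,
        max_speech_length=100000,
        min_speech_length=15,
    )
    return vad_post_process.post_process(vad_info["probs"])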

def main():
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download models
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # engines
    global vad_engines
    vad_engines = {
        filename.stem: {
            "infer_cls": get_infer_cls_by_model_name(filename.stem),
            "kwargs": {
                "pretrained_model_path_or_zip_file": filename.as_posix()
            }
        }
        for filename in trained_model_dir.glob("*.zip")
        if filename.name not in (
            "cnn-vad-by-webrtcvad-nx-dns3.zip",
            "fsmn-vad-by-webrtcvad-nx-dns3.zip",
            "examples.zip",
            "sound-2-ch32.zip",
            "sound-3-ch32.zip",
            "sound-4-ch32.zip",
            "sound-8-ch32.zip",
        )
    }

    # choices
    vad_engine_choices = list(vad_engines.keys())

    # extract examples archive
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            out_root = examples_dir
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

    # examples
    examples = list()
    for filename in examples_dir.glob("**/*.wav"):
        examples.append([
            filename.as_posix(),
            None,
            vad_engine_choices[0],
        ])

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            with gr.TabItem("vad"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                vad_audio_file = gr.Audio(label="audio")
                            with gr.TabItem("microphone"):
                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
                        with gr.Row():
                            vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
                            vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
                            vad_min_silence_length = gr.Number(value=30, label="min_silence_length")
                        with gr.Row():
                            vad_max_speech_length = gr.Number(value=100000, label="max_speech_length")
                            vad_min_speech_length = gr.Number(value=15, label="min_speech_length")
                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
                        vad_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        vad_vad_image = gr.Image(label="vad")
                        vad_prob_image = gr.Image(label="prob")
                        vad_lsnr_image = gr.Image(label="lsnr")
                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")

                vad_button.click(
                    when_click_vad_button,
                    inputs=[
                        vad_audio_file, vad_audio_microphone,
                        vad_start_ring_rate, vad_end_ring_rate,
                        vad_min_silence_length, vad_max_speech_length, vad_min_speech_length,
                        vad_engine,
                    ],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                )
                gr.Examples(
                    examples=examples,
                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                    fn=when_click_vad_button,
                    # cache_examples=True,
                    # cache_mode="lazy",
                )
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    # http://127.0.0.1:7866/
    # http://10.75.27.247:7866/
    blocks.queue().launch(
        # share=True,
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return
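
# Typical local launch, assuming this script is saved as main.py at the project
# root (adjust the filename and port to your setup):
#
#   python3 main.py --server_port 7860
#
# On Windows the app binds to 127.0.0.1; on other systems it binds to 0.0.0.0
# so it is reachable from other machines on the network.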

if __name__ == "__main__":
    main()