#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from functools import lru_cache
import json
import logging
from pathlib import Path
import platform
import shutil
import tempfile
import time
from typing import Dict, Optional, Tuple
import uuid
import zipfile

import gradio as gr
import librosa
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile

import log
from project_settings import environment, project_path, log_directory, time_zone_info
from toolbox.os.command import Command
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.utils.visualization import process_speech_probs
from toolbox.vad.utils import PostProcess

log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)

logger = logging.getLogger("main")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        # default=(project_path / "data").as_posix(),
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/cc_vad",
        type=str
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )
    args = parser.parse_args()
    return args


def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
    # write the int16 signal to a temporary wav file and return its path.
    if signal.dtype != np.int16:
        raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")

    temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
    temp_audio_dir.mkdir(parents=True, exist_ok=True)
    filename = (temp_audio_dir / f"{uuid.uuid4()}.wav").as_posix()

    wavfile.write(
        filename,
        sample_rate,
        signal
    )
    return filename


def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
    # resample by round-tripping through a temporary wav file.
    filename = save_input_audio(sample_rate, signal)

    signal, _ = librosa.load(filename, sr=target_sample_rate)
    signal = np.array(signal * (1 << 15), dtype=np.int16)
    return signal


def shell(cmd: str):
    return Command.popen(cmd)


def get_infer_cls_by_model_name(model_name: str):
    # choose the inference class from the model filename.
    if "fsmn" in model_name:
        infer_cls = InferenceFSMNVadOnnx
    elif "silero" in model_name:
        infer_cls = InferenceSileroVad
    else:
        raise AssertionError(f"unsupported model name: {model_name}")
    return infer_cls


vad_engines: Optional[Dict[str, dict]] = None


@lru_cache(maxsize=1)
def load_vad_model(infer_cls, **kwargs):
    # cache the most recently used engine so repeated requests reuse the loaded model.
    infer_engine = infer_cls(**kwargs)
    return infer_engine


def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
    # plot the waveform and per-sample speech probabilities, save to a temporary png.
    duration = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(duration, signal, color='b')
    plt.plot(duration, speech_probs, color='gray')
    plt.title(title)

    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    plt.savefig(temp_file.name, bbox_inches="tight")
    plt.close()
    return temp_file.name


def when_click_vad_button(audio_file_t=None, audio_microphone_t=None,
                          start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
                          ring_max_length: int = 10,
                          min_silence_length: int = 2,
                          max_speech_length: int = 10000, min_speech_length: int = 10,
                          engine: str = None,
                          ):
    if audio_file_t is None and audio_microphone_t is None:
        raise gr.Error("both audio file and microphone are null.")
    if audio_file_t is not None and audio_microphone_t is not None:
        gr.Warning("both audio file and microphone are provided; the audio file takes priority.")

    audio_t: Tuple = audio_file_t or audio_microphone_t

    sample_rate, signal = audio_t

    if sample_rate != 8000:
        signal = convert_sample_rate(signal, sample_rate, 8000)
        sample_rate = 8000

    audio_duration = signal.shape[-1] / sample_rate
    audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = vad_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid vad engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)

        begin = time.time()
        vad_info = infer_engine.infer(audio)
        time_cost = time.time() - begin

        probs = vad_info["probs"]
        lsnr = vad_info["lsnr"]
        # lsnr = lsnr / np.max(np.abs(lsnr))
        lsnr = lsnr / 30

        frame_step = infer_engine.config.hop_size

        # post process
        vad_post_process = PostProcess(
            start_ring_rate=start_ring_rate,
            end_ring_rate=end_ring_rate,
            ring_max_length=ring_max_length,
            min_silence_length=min_silence_length,
            max_speech_length=max_speech_length,
            min_speech_length=min_speech_length
        )
        vad_segments = vad_post_process.get_vad_segments(probs)
        vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)

        # vad_image
        vad_ = process_speech_probs(audio, vad_flags, frame_step)
        vad_image = generate_image(audio, vad_)

        # probs_image
        probs_ = process_speech_probs(audio, probs, frame_step)
        probs_image = generate_image(audio, probs_)

        # lsnr_image
        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
        lsnr_image = generate_image(audio, lsnr_)

        # vad segments, converted from frame indices to seconds
        vad_segments = [
            [
                v[0] * frame_step / sample_rate,
                v[1] * frame_step / sample_rate
            ]
            for v in vad_segments
        ]

        # message
        rtf = time_cost / audio_duration
        info = {
            "vad_segments": vad_segments,
            "time_cost": round(time_cost, 4),
            "duration": round(audio_duration, 4),
            "rtf": round(rtf, 4)
        }
        message = json.dumps(info, ensure_ascii=False, indent=4)
    except Exception as e:
        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")

    return vad_image, probs_image, lsnr_image, message


def main():
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download models
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # engines
    global vad_engines
    vad_engines = {
        filename.stem: {
            "infer_cls": get_infer_cls_by_model_name(filename.stem),
            "kwargs": {
                "pretrained_model_path_or_zip_file": filename.as_posix()
            }
        }
        for filename in trained_model_dir.glob("*.zip")
        if filename.name not in (
            "cnn-vad-by-webrtcvad-nx-dns3.zip",
            "fsmn-vad-by-webrtcvad-nx-dns3.zip",
            "examples.zip",
            "sound-2-ch32.zip",
            "sound-3-ch32.zip",
            "sound-4-ch32.zip",
            "sound-8-ch32.zip",
        )
    }

    # choices
    vad_engine_choices = list(vad_engines.keys())

    # unpack examples
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            out_root = examples_dir
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

    # examples
    examples = list()
    for filename in examples_dir.glob("**/*.wav"):
        examples.append([
            filename.as_posix(),
            None,
            vad_engine_choices[0],
        ])

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            with gr.TabItem("vad"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                vad_audio_file = gr.Audio(label="audio")
                            with gr.TabItem("microphone"):
                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")

                        with gr.Row():
                            vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
                            vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
                        with gr.Row():
                            vad_ring_max_length = gr.Number(value=10, label="ring_max_length (*10ms)")
                            vad_min_silence_length = gr.Number(value=6, label="min_silence_length (*10ms)")
                        with gr.Row():
                            vad_max_speech_length = gr.Number(value=100000, label="max_speech_length (*10ms)")
                            vad_min_speech_length = gr.Number(value=15, label="min_speech_length (*10ms)")

                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
                        vad_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        vad_vad_image = gr.Image(label="vad")
                        vad_prob_image = gr.Image(label="prob")
                        vad_lsnr_image = gr.Image(label="lsnr")
                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")

                vad_button.click(
                    when_click_vad_button,
                    inputs=[
                        vad_audio_file, vad_audio_microphone,
                        vad_start_ring_rate, vad_end_ring_rate,
                        vad_ring_max_length,
                        vad_min_silence_length,
                        vad_max_speech_length, vad_min_speech_length,
                        vad_engine,
                    ],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                )
                gr.Examples(
                    examples=examples,
                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                    fn=when_click_vad_button,
                    # cache_examples=True,
                    # cache_mode="lazy",
                )
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    # http://127.0.0.1:7866/
    # http://10.75.27.247:7866/
    blocks.queue().launch(
        # share=True,
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
        show_error=True
    )
    return


if __name__ == "__main__":
    main()
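
# Usage sketch (assuming this file is saved as main.py; the repo id and port below
# are just the argparse defaults from get_args, shown for illustration):
#
#   python3 main.py --models_repo_id qgyd2021/cc_vad --server_port 7860
#
# On non-Windows hosts the Gradio app then listens on 0.0.0.0:7860; on Windows it
# binds to 127.0.0.1 instead.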