Spaces:
Running
Running
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| docker build -t cc_audio_8:v20250828_1343 . | |
| docker stop cc_audio_8_7864 && docker rm cc_audio_8_7864 | |
| docker run -itd \ | |
| --name cc_audio_8_7864 \ | |
| --restart=always \ | |
| --network host \ | |
| -e server_port=7865 \ | |
| cc_audio_8:v20250828_1343 /bin/bash | |
| docker run -itd \ | |
| --name cc_audio_8_7864 \ | |
| --network host \ | |
| --gpus all \ | |
| --privileged \ | |
| --ipc=host \ | |
| python:3.12 /bin/bash | |
| nohup python3 main.py --server_port 7864 --hf_token <your_hf_token> & | |
| """ | |
| import argparse | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import platform | |
| import shutil | |
| import tempfile | |
| import zipfile | |
| from typing import Tuple | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import numpy as np | |
| import torch | |
| from project_settings import environment, project_path | |
| from toolbox.torch.utils.data.vocabulary import Vocabulary | |
| from tabs.cls_tab import get_cls_tab | |
| from tabs.split_tab import get_split_tab | |
| from tabs.event_tab import get_event_tab | |
| from tabs.shell_tab import get_shell_tab | |
def get_args():
    """Parse the command-line options for the demo server.

    Directory defaults are derived from ``project_path``; the Hugging Face
    token and the server port default to values read from the project
    ``environment`` mapping (port falls back to 7860).

    :return: ``argparse.Namespace`` with ``examples_dir``, ``models_repo_id``,
        ``trained_model_dir``, ``hf_token`` and ``server_port``.
    """
    # (flag, default value, type converter), registered in this order.
    option_specs = (
        ("--examples_dir", (project_path / "data/examples").as_posix(), str),
        ("--models_repo_id", "qgyd2021/cc_audio_8", str),
        ("--trained_model_dir", (project_path / "trained_models").as_posix(), str),
        ("--hf_token", environment.get("hf_token"), str),
        ("--server_port", environment.get("server_port", 7860), int),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, converter in option_specs:
        parser.add_argument(flag, default=default_value, type=converter)
    return parser.parse_args()
def load_model(model_file: Path):
    """Load a packaged TorchScript model and its vocabulary from a zip archive.

    The archive must contain a top-level directory named after the archive's
    stem. That directory holds ``trace_model.zip`` (loaded with
    ``torch.jit.load``) and a ``vocabulary`` directory (loaded with
    ``Vocabulary.from_files``).

    Extraction happens in a private, per-call temporary directory. That
    directory is always removed afterwards, even if loading fails. As a
    result, overlapping calls cannot delete each other's files and no
    extraction residue is left in the system temp dir.

    :param model_file: path to the model zip archive (a ``str`` path is also
        accepted).
    :return: dict with keys ``"model"`` (the TorchScript module, switched to
        eval mode) and ``"vocabulary"`` (the loaded ``Vocabulary``).
    """
    model_file = Path(model_file)

    with tempfile.TemporaryDirectory(prefix="cc_audio_8_") as tmp_dir:
        out_root = Path(tmp_dir)
        # Close the archive as soon as it has been extracted.
        with zipfile.ZipFile(model_file, "r") as f_zip:
            f_zip.extractall(path=out_root)

        tgt_path = out_root / model_file.stem
        jit_model_file = tgt_path / "trace_model.zip"
        vocab_path = tgt_path / "vocabulary"

        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
        with open(jit_model_file, "rb") as f:
            model = torch.jit.load(f)
        model.eval()
        # The extracted files are deleted when this block exits; the model and
        # vocabulary are assumed to hold everything they need in memory by now.

    return {
        "model": model,
        "vocabulary": vocabulary,
    }
def _download_models(trained_model_dir: Path, repo_id: str, token) -> None:
    """Download the model repo into ``trained_model_dir`` unless it already exists."""
    if trained_model_dir.exists():
        return
    trained_model_dir.mkdir(parents=True, exist_ok=True)
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=token,
        )
    except BaseException:
        # The directory did not exist before this call. Remove the
        # half-populated copy so the next start retries the download instead
        # of skipping it because the directory is present.
        shutil.rmtree(trained_model_dir, ignore_errors=True)
        raise


def _extract_examples(examples_dir: Path, trained_model_dir: Path) -> None:
    """Unpack ``examples.zip`` from the model dir into ``examples_dir`` if it is missing."""
    if examples_dir.exists():
        return
    example_zip_file = trained_model_dir / "examples.zip"
    # Open the archive first, so a missing zip fails before the directory is created.
    with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
        examples_dir.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=examples_dir)


def main():
    """Fetch models and examples if needed, then launch the Gradio demo UI.

    Steps:
      1. Download ``args.models_repo_id`` into ``args.trained_model_dir`` when
         that directory does not exist yet.
      2. Extract ``examples.zip`` from the model directory into
         ``args.examples_dir`` when that directory does not exist yet.
      3. Build the classification, event, split and shell tabs and serve them.
    """
    args = get_args()
    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    _download_models(trained_model_dir, args.models_repo_id, args.hf_token)
    _extract_examples(examples_dir, trained_model_dir)

    with gr.Blocks() as blocks:
        with gr.Tabs():
            get_cls_tab(
                examples_dir=args.examples_dir,
                trained_model_dir=args.trained_model_dir,
            )
            get_event_tab(
                examples_dir=args.examples_dir,
                trained_model_dir=args.trained_model_dir,
            )
            get_split_tab(
                examples_dir=args.examples_dir,
                trained_model_dir=args.trained_model_dir,
            )
            get_shell_tab()

    is_windows = platform.system() == "Windows"
    # e.g. http://127.0.0.1:7864/ when started with --server_port 7864.
    blocks.queue().launch(
        share=False,
        # Loopback only on Windows; all interfaces elsewhere (e.g. in Docker).
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port,
    )
# Start the demo server only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()