#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
docker build -t cc_audio_8:v20250828_1343 .
docker stop cc_audio_8_7864 && docker rm cc_audio_8_7864
docker run -itd \
--name cc_audio_8_7864 \
--restart=always \
--network host \
-e server_port=7865 \
cc_audio_8:v20250828_1343 /bin/bash
docker run -itd \
--name cc_audio_8_7864 \
--network host \
--gpus all \
--privileged \
--ipc=host \
python:3.12 /bin/bash
nohup python3 main.py --server_port 7864 --hf_token hf_coRVvzwA****jLmZHwJobEX &
"""
import argparse
from functools import lru_cache
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile

import gradio as gr
from huggingface_hub import snapshot_download
import torch

from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
from tabs.cls_tab import get_cls_tab
from tabs.split_tab import get_split_tab
from tabs.event_tab import get_event_tab
from tabs.shell_tab import get_shell_tab


def get_args():
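    """Parse command line arguments; defaults come from `project_path` and the `environment` settings."""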
parser = argparse.ArgumentParser()
parser.add_argument(
"--examples_dir",
# default=(project_path / "data").as_posix(),
default=(project_path / "data/examples").as_posix(),
type=str
)
parser.add_argument(
"--models_repo_id",
default="qgyd2021/cc_audio_8",
type=str
)
parser.add_argument(
"--trained_model_dir",
default=(project_path / "trained_models").as_posix(),
type=str
)
parser.add_argument(
"--hf_token",
default=environment.get("hf_token"),
type=str,
)
parser.add_argument(
"--server_port",
default=environment.get("server_port", 7860),
type=int
)
args = parser.parse_args()
return args


@lru_cache(maxsize=100)
def load_model(model_file: Path):
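    """
    Extract a zipped model archive and load the TorchScript model plus its vocabulary.

    The archive is expected to contain `<stem>/trace_model.zip` (a traced
    torch.jit model) and `<stem>/vocabulary/` (a serialized Vocabulary).
    Returns a dict with keys "model" and "vocabulary"; results are memoized
    via `lru_cache`, so each archive is extracted and loaded only once.

    Example (hypothetical archive name, for illustration only):
        loaded = load_model(Path("trained_models/voicemail.zip"))
        model, vocabulary = loaded["model"], loaded["vocabulary"]
    """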
with zipfile.ZipFile(model_file, "r") as f_zip:
out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
if out_root.exists():
shutil.rmtree(out_root.as_posix())
out_root.mkdir(parents=True, exist_ok=True)
f_zip.extractall(path=out_root)
tgt_path = out_root / model_file.stem
jit_model_file = tgt_path / "trace_model.zip"
vocab_path = tgt_path / "vocabulary"
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
with open(jit_model_file.as_posix(), "rb") as f:
model = torch.jit.load(f)
model.eval()
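    # remove the extracted files; the model and vocabulary are now in memory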
shutil.rmtree(tgt_path)
d = {
"model": model,
"vocabulary": vocabulary
}
return d


def main():
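    """Download the trained models, unpack the bundled examples, and launch the Gradio UI."""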
args = get_args()
examples_dir = Path(args.examples_dir)
trained_model_dir = Path(args.trained_model_dir)
# download models
if not trained_model_dir.exists():
trained_model_dir.mkdir(parents=True, exist_ok=True)
_ = snapshot_download(
repo_id=args.models_repo_id,
local_dir=trained_model_dir.as_posix(),
token=args.hf_token,
)
# examples zip
if not examples_dir.exists():
example_zip_file = trained_model_dir / "examples.zip"
with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
out_root = examples_dir
if out_root.exists():
shutil.rmtree(out_root.as_posix())
out_root.mkdir(parents=True, exist_ok=True)
f_zip.extractall(path=out_root)
# ui
with gr.Blocks() as blocks:
with gr.Tabs():
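            # one tab per task: cls (classification), event detection, split, and a shell for debugging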
_ = get_cls_tab(
examples_dir=args.examples_dir,
trained_model_dir=args.trained_model_dir,
)
_ = get_event_tab(
examples_dir=args.examples_dir,
trained_model_dir=args.trained_model_dir,
)
_ = get_split_tab(
examples_dir=args.examples_dir,
trained_model_dir=args.trained_model_dir,
)
_ = get_shell_tab()
# http://127.0.0.1:7864/
blocks.queue().launch(
        share=False,
server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
server_port=args.server_port
)
return


if __name__ == "__main__":
main()