+```
+
+## Options
+
+| Option | Description |
+| :------------------------ | :------------------------------------------------------------------------------ |
+| `--audio-dir` | Directory containing audio or video files. |
+| `--save-dir` | Directory to save processed audio files. |
+| `--device` | Device to use for processing. Options: `cuda` (default) or `cpu`. |
+| `--language` | Language of the transcription. Default is `auto`. |
+| `--max_single_segment_time` | Maximum duration of a single audio segment in milliseconds. Default is `20000`. |
+| `--fsmn-vad` / `--silero-vad` | Select the VAD backend. Silero VAD is the default. |
+| `--punc` | Enable punctuation prediction. |
+| `--denoise` | Enable noise reduction (vocal separation). |
+| `--save_emo` | Save detected emotion labels alongside the transcripts. |
+
+## Example
+
+To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
+```
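+
+The VAD backend and emotion output can also be selected on the command line; the flags below correspond to the `--fsmn-vad/--silero-vad` and `--save_emo` options defined in `fun_asr.py`:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --silero-vad --punc --save_emo
+```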
+
+## Additional Notes
+
+- The tool supports both audio and video files. Videos are automatically converted to audio before processing.
+- If the `--denoise` option is used, the tool performs vocal separation to isolate the vocals from the instrumental track.
+- The script automatically creates any necessary subdirectories under `--save-dir`.
+
+## Troubleshooting
+
+If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py
index dd2e186617fe889500d01d95eccdafc5c0248b84..fc290219cffb8813f5ae0cfc6801b51149261854 100644
--- a/tools/sensevoice/auto_model.py
+++ b/tools/sensevoice/auto_model.py
@@ -1,573 +1,573 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import copy
-import json
-import logging
-import os.path
-import random
-import re
-import string
-import time
-
-import numpy as np
-import torch
-from funasr.download.download_model_from_hub import download_model
-from funasr.download.file import download_from_url
-from funasr.register import tables
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import export_utils, misc
-from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
-from funasr.utils.misc import deep_update
-from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
-from tqdm import tqdm
-
-from .vad_utils import merge_vad, slice_padding_audio_samples
-
-try:
- from funasr.models.campplus.cluster_backend import ClusterBackend
- from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
-except:
- pass
-
-
-def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
- """ """
- data_list = []
- key_list = []
- filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
-
- chars = string.ascii_letters + string.digits
- if isinstance(data_in, str):
- if data_in.startswith("http://") or data_in.startswith("https://"): # url
- data_in = download_from_url(data_in)
-
- if isinstance(data_in, str) and os.path.exists(
- data_in
- ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
- _, file_extension = os.path.splitext(data_in)
- file_extension = file_extension.lower()
- if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
- with open(data_in, encoding="utf-8") as fin:
- for line in fin:
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
- if data_in.endswith(
- ".jsonl"
- ): # file.jsonl: json.dumps({"source": data})
- lines = json.loads(line.strip())
- data = lines["source"]
- key = data["key"] if "key" in data else key
- else: # filelist, wav.scp, text.txt: id \t data or data
- lines = line.strip().split(maxsplit=1)
- data = lines[1] if len(lines) > 1 else lines[0]
- key = lines[0] if len(lines) > 1 else key
-
- data_list.append(data)
- key_list.append(key)
- else:
- if key is None:
- # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
- key = misc.extract_filename_without_extension(data_in)
- data_list = [data_in]
- key_list = [key]
- elif isinstance(data_in, (list, tuple)):
- if data_type is not None and isinstance(
- data_type, (list, tuple)
- ): # mutiple inputs
- data_list_tmp = []
- for data_in_i, data_type_i in zip(data_in, data_type):
- key_list, data_list_i = prepare_data_iterator(
- data_in=data_in_i, data_type=data_type_i
- )
- data_list_tmp.append(data_list_i)
- data_list = []
- for item in zip(*data_list_tmp):
- data_list.append(item)
- else:
- # [audio sample point, fbank, text]
- data_list = data_in
- key_list = []
- for data_i in data_in:
- if isinstance(data_i, str) and os.path.exists(data_i):
- key = misc.extract_filename_without_extension(data_i)
- else:
- if key is None:
- key = "rand_key_" + "".join(
- random.choice(chars) for _ in range(13)
- )
- key_list.append(key)
-
- else: # raw text; audio sample point, fbank; bytes
- if isinstance(data_in, bytes): # audio bytes
- data_in = load_bytes(data_in)
- if key is None:
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
- data_list = [data_in]
- key_list = [key]
-
- return key_list, data_list
-
-
-class AutoModel:
-
- def __init__(self, **kwargs):
-
- try:
- from funasr.utils.version_checker import check_for_update
-
- print(
- "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
- )
- check_for_update(disable=kwargs.get("disable_update", False))
- except:
- pass
-
- log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
- logging.basicConfig(level=log_level)
-
- model, kwargs = self.build_model(**kwargs)
-
- # if vad_model is not None, build vad model else None
- vad_model = kwargs.get("vad_model", None)
- vad_kwargs = (
- {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
- )
- if vad_model is not None:
- logging.info("Building VAD model.")
- vad_kwargs["model"] = vad_model
- vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
- vad_kwargs["device"] = kwargs["device"]
- vad_model, vad_kwargs = self.build_model(**vad_kwargs)
-
- # if punc_model is not None, build punc model else None
- punc_model = kwargs.get("punc_model", None)
- punc_kwargs = (
- {}
- if kwargs.get("punc_kwargs", {}) is None
- else kwargs.get("punc_kwargs", {})
- )
- if punc_model is not None:
- logging.info("Building punc model.")
- punc_kwargs["model"] = punc_model
- punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
- punc_kwargs["device"] = kwargs["device"]
- punc_model, punc_kwargs = self.build_model(**punc_kwargs)
-
- # if spk_model is not None, build spk model else None
- spk_model = kwargs.get("spk_model", None)
- spk_kwargs = (
- {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
- )
- if spk_model is not None:
- logging.info("Building SPK model.")
- spk_kwargs["model"] = spk_model
- spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
- spk_kwargs["device"] = kwargs["device"]
- spk_model, spk_kwargs = self.build_model(**spk_kwargs)
- self.cb_model = ClusterBackend().to(kwargs["device"])
- spk_mode = kwargs.get("spk_mode", "punc_segment")
- if spk_mode not in ["default", "vad_segment", "punc_segment"]:
- logging.error(
- "spk_mode should be one of default, vad_segment and punc_segment."
- )
- self.spk_mode = spk_mode
-
- self.kwargs = kwargs
- self.model = model
- self.vad_model = vad_model
- self.vad_kwargs = vad_kwargs
- self.punc_model = punc_model
- self.punc_kwargs = punc_kwargs
- self.spk_model = spk_model
- self.spk_kwargs = spk_kwargs
- self.model_path = kwargs.get("model_path")
-
- @staticmethod
- def build_model(**kwargs):
- assert "model" in kwargs
- if "model_conf" not in kwargs:
- logging.info(
- "download models from model hub: {}".format(kwargs.get("hub", "ms"))
- )
- kwargs = download_model(**kwargs)
-
- set_all_random_seed(kwargs.get("seed", 0))
-
- device = kwargs.get("device", "cuda")
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
- device = "cpu"
- kwargs["batch_size"] = 1
- kwargs["device"] = device
-
- torch.set_num_threads(kwargs.get("ncpu", 4))
-
- # build tokenizer
- tokenizer = kwargs.get("tokenizer", None)
- if tokenizer is not None:
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
- tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
- kwargs["token_list"] = (
- tokenizer.token_list if hasattr(tokenizer, "token_list") else None
- )
- kwargs["token_list"] = (
- tokenizer.get_vocab()
- if hasattr(tokenizer, "get_vocab")
- else kwargs["token_list"]
- )
- vocab_size = (
- len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
- )
- if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
- vocab_size = tokenizer.get_vocab_size()
- else:
- vocab_size = -1
- kwargs["tokenizer"] = tokenizer
-
- # build frontend
- frontend = kwargs.get("frontend", None)
- kwargs["input_size"] = None
- if frontend is not None:
- frontend_class = tables.frontend_classes.get(frontend)
- frontend = frontend_class(**kwargs.get("frontend_conf", {}))
- kwargs["input_size"] = (
- frontend.output_size() if hasattr(frontend, "output_size") else None
- )
- kwargs["frontend"] = frontend
- # build model
- model_class = tables.model_classes.get(kwargs["model"])
- assert model_class is not None, f'{kwargs["model"]} is not registered'
- model_conf = {}
- deep_update(model_conf, kwargs.get("model_conf", {}))
- deep_update(model_conf, kwargs)
- model = model_class(**model_conf, vocab_size=vocab_size)
-
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- if os.path.exists(init_param):
- logging.info(f"Loading pretrained params from {init_param}")
- load_pretrained_model(
- model=model,
- path=init_param,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
- oss_bucket=kwargs.get("oss_bucket", None),
- scope_map=kwargs.get("scope_map", []),
- excludes=kwargs.get("excludes", None),
- )
- else:
- print(f"error, init_param does not exist!: {init_param}")
-
- # fp16
- if kwargs.get("fp16", False):
- model.to(torch.float16)
- elif kwargs.get("bf16", False):
- model.to(torch.bfloat16)
- model.to(device)
-
- if not kwargs.get("disable_log", True):
- tables.print()
-
- return model, kwargs
-
- def __call__(self, *args, **cfg):
- kwargs = self.kwargs
- deep_update(kwargs, cfg)
- res = self.model(*args, kwargs)
- return res
-
- def generate(self, input, input_len=None, **cfg):
- if self.vad_model is None:
- return self.inference(input, input_len=input_len, **cfg)
-
- else:
- return self.inference_with_vad(input, input_len=input_len, **cfg)
-
- def inference(
- self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
- ):
- kwargs = self.kwargs if kwargs is None else kwargs
- if "cache" in kwargs:
- kwargs.pop("cache")
- deep_update(kwargs, cfg)
- model = self.model if model is None else model
- model.eval()
-
- batch_size = kwargs.get("batch_size", 1)
- # if kwargs.get("device", "cpu") == "cpu":
- # batch_size = 1
-
- key_list, data_list = prepare_data_iterator(
- input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
- )
-
- speed_stats = {}
- asr_result_list = []
- num_samples = len(data_list)
- disable_pbar = self.kwargs.get("disable_pbar", False)
- pbar = (
- tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
- if not disable_pbar
- else None
- )
- time_speech_total = 0.0
- time_escape_total = 0.0
- for beg_idx in range(0, num_samples, batch_size):
- end_idx = min(num_samples, beg_idx + batch_size)
- data_batch = data_list[beg_idx:end_idx]
- key_batch = key_list[beg_idx:end_idx]
- batch = {"data_in": data_batch, "key": key_batch}
-
- if (end_idx - beg_idx) == 1 and kwargs.get(
- "data_type", None
- ) == "fbank": # fbank
- batch["data_in"] = data_batch[0]
- batch["data_lengths"] = input_len
-
- time1 = time.perf_counter()
- with torch.no_grad():
- res = model.inference(**batch, **kwargs)
- if isinstance(res, (list, tuple)):
- results = res[0] if len(res) > 0 else [{"text": ""}]
- meta_data = res[1] if len(res) > 1 else {}
- time2 = time.perf_counter()
-
- asr_result_list.extend(results)
-
- # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
- batch_data_time = meta_data.get("batch_data_time", -1)
- time_escape = time2 - time1
- speed_stats["load_data"] = meta_data.get("load_data", 0.0)
- speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
- speed_stats["forward"] = f"{time_escape:0.3f}"
- speed_stats["batch_size"] = f"{len(results)}"
- speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
- description = f"{speed_stats}, "
- if pbar:
- pbar.update(end_idx - beg_idx)
- pbar.set_description(description)
- time_speech_total += batch_data_time
- time_escape_total += time_escape
-
- if pbar:
- # pbar.update(1)
- pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
- torch.cuda.empty_cache()
- return asr_result_list
-
- def vad(self, input, input_len=None, **cfg):
- kwargs = self.kwargs
- # step.1: compute the vad model
- deep_update(self.vad_kwargs, cfg)
- beg_vad = time.time()
- res = self.inference(
- input,
- input_len=input_len,
- model=self.vad_model,
- kwargs=self.vad_kwargs,
- **cfg,
- )
- end_vad = time.time()
- # FIX(gcf): concat the vad clips for sense vocie model for better aed
- if cfg.get("merge_vad", False):
- for i in range(len(res)):
- res[i]["value"] = merge_vad(
- res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
- )
- elapsed = end_vad - beg_vad
- return elapsed, res
-
- def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
-
- kwargs = self.kwargs
-
- # step.2 compute asr model
- model = self.model
- deep_update(kwargs, cfg)
- batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
- batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
- kwargs["batch_size"] = batch_size
-
- key_list, data_list = prepare_data_iterator(
- input, input_len=input_len, data_type=kwargs.get("data_type", None)
- )
- results_ret_list = []
- time_speech_total_all_samples = 1e-6
-
- beg_total = time.time()
- pbar_total = (
- tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
- if not kwargs.get("disable_pbar", False)
- else None
- )
-
- for i in range(len(vad_res)):
- key = vad_res[i]["key"]
- vadsegments = vad_res[i]["value"]
- input_i = data_list[i]
- fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
- speech = load_audio_text_image_video(
- input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
- )
- speech_lengths = len(speech)
- n = len(vadsegments)
- data_with_index = [(vadsegments[i], i) for i in range(n)]
- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
- results_sorted = []
-
- if not len(sorted_data):
- results_ret_list.append({"key": key, "text": "", "timestamp": []})
- logging.info("decoding, utt: {}, empty speech".format(key))
- continue
-
- if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
- batch_size = max(
- batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
- )
-
- if kwargs["device"] == "cpu":
- batch_size = 0
-
- beg_idx = 0
- beg_asr_total = time.time()
- time_speech_total_per_sample = speech_lengths / 16000
- time_speech_total_all_samples += time_speech_total_per_sample
-
- # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
-
- all_segments = []
- max_len_in_batch = 0
- end_idx = 1
-
- for j, _ in enumerate(range(0, n)):
- # pbar_sample.update(1)
- sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
- potential_batch_length = max(max_len_in_batch, sample_length) * (
- j + 1 - beg_idx
- )
- # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
- if (
- j < n - 1
- and sample_length < batch_size_threshold_ms
- and potential_batch_length < batch_size
- ):
- max_len_in_batch = max(max_len_in_batch, sample_length)
- end_idx += 1
- continue
-
- speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
- speech, speech_lengths, sorted_data[beg_idx:end_idx]
- )
- results = self.inference(
- speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
- )
-
- for _b in range(len(speech_j)):
- results[_b]["interval"] = intervals[_b]
-
- if self.spk_model is not None:
- # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
- for _b in range(len(speech_j)):
- vad_segments = [
- [
- sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
- sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
- np.array(speech_j[_b]),
- ]
- ]
- segments = sv_chunk(vad_segments)
- all_segments.extend(segments)
- speech_b = [i[2] for i in segments]
- spk_res = self.inference(
- speech_b,
- input_len=None,
- model=self.spk_model,
- kwargs=kwargs,
- **cfg,
- )
- results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
-
- beg_idx = end_idx
- end_idx += 1
- max_len_in_batch = sample_length
- if len(results) < 1:
- continue
- results_sorted.extend(results)
-
- # end_asr_total = time.time()
- # time_escape_total_per_sample = end_asr_total - beg_asr_total
- # pbar_sample.update(1)
- # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
- # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
- # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
-
- restored_data = [0] * n
- for j in range(n):
- index = sorted_data[j][1]
- cur = results_sorted[j]
- pattern = r"<\|([^|]+)\|>"
- emotion_string = re.findall(pattern, cur["text"])
- cur["text"] = re.sub(pattern, "", cur["text"])
- cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
- if self.punc_model is not None and len(cur["text"].strip()) > 0:
- deep_update(self.punc_kwargs, cfg)
- punc_res = self.inference(
- cur["text"],
- model=self.punc_model,
- kwargs=self.punc_kwargs,
- **cfg,
- )
- cur["text"] = punc_res[0]["text"]
-
- restored_data[index] = cur
-
- end_asr_total = time.time()
- time_escape_total_per_sample = end_asr_total - beg_asr_total
- if pbar_total:
- pbar_total.update(1)
- pbar_total.set_description(
- f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
- f"time_speech: {time_speech_total_per_sample: 0.3f}, "
- f"time_escape: {time_escape_total_per_sample:0.3f}"
- )
-
- # end_total = time.time()
- # time_escape_total_all_samples = end_total - beg_total
- # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
- # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
- # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
- return restored_data
-
- def export(self, input=None, **cfg):
- """
-
- :param input:
- :param type:
- :param quantize:
- :param fallback_num:
- :param calib_num:
- :param opset_version:
- :param cfg:
- :return:
- """
-
- device = cfg.get("device", "cpu")
- model = self.model.to(device=device)
- kwargs = self.kwargs
- deep_update(kwargs, cfg)
- kwargs["device"] = device
- del kwargs["model"]
- model.eval()
-
- type = kwargs.get("type", "onnx")
-
- key_list, data_list = prepare_data_iterator(
- input, input_len=None, data_type=kwargs.get("data_type", None), key=None
- )
-
- with torch.no_grad():
- export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
-
- return export_dir
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+ from funasr.models.campplus.cluster_backend import ClusterBackend
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+ pass
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+ """ """
+ data_list = []
+ key_list = []
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
+
+ chars = string.ascii_letters + string.digits
+ if isinstance(data_in, str):
+ if data_in.startswith("http://") or data_in.startswith("https://"): # url
+ data_in = download_from_url(data_in)
+
+ if isinstance(data_in, str) and os.path.exists(
+ data_in
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+ _, file_extension = os.path.splitext(data_in)
+ file_extension = file_extension.lower()
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
+ with open(data_in, encoding="utf-8") as fin:
+ for line in fin:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ if data_in.endswith(
+ ".jsonl"
+ ): # file.jsonl: json.dumps({"source": data})
+ lines = json.loads(line.strip())
+ data = lines["source"]
+ key = data["key"] if "key" in data else key
+ else: # filelist, wav.scp, text.txt: id \t data or data
+ lines = line.strip().split(maxsplit=1)
+ data = lines[1] if len(lines) > 1 else lines[0]
+ key = lines[0] if len(lines) > 1 else key
+
+ data_list.append(data)
+ key_list.append(key)
+ else:
+ if key is None:
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ key = misc.extract_filename_without_extension(data_in)
+ data_list = [data_in]
+ key_list = [key]
+ elif isinstance(data_in, (list, tuple)):
+ if data_type is not None and isinstance(
+ data_type, (list, tuple)
+        ): # multiple inputs
+ data_list_tmp = []
+ for data_in_i, data_type_i in zip(data_in, data_type):
+ key_list, data_list_i = prepare_data_iterator(
+ data_in=data_in_i, data_type=data_type_i
+ )
+ data_list_tmp.append(data_list_i)
+ data_list = []
+ for item in zip(*data_list_tmp):
+ data_list.append(item)
+ else:
+ # [audio sample point, fbank, text]
+ data_list = data_in
+ key_list = []
+ for data_i in data_in:
+ if isinstance(data_i, str) and os.path.exists(data_i):
+ key = misc.extract_filename_without_extension(data_i)
+ else:
+ if key is None:
+ key = "rand_key_" + "".join(
+ random.choice(chars) for _ in range(13)
+ )
+ key_list.append(key)
+
+ else: # raw text; audio sample point, fbank; bytes
+ if isinstance(data_in, bytes): # audio bytes
+ data_in = load_bytes(data_in)
+ if key is None:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ data_list = [data_in]
+ key_list = [key]
+
+ return key_list, data_list
+
+
+class AutoModel:
+
+ def __init__(self, **kwargs):
+
+ try:
+ from funasr.utils.version_checker import check_for_update
+
+ print(
+ "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
+ )
+ check_for_update(disable=kwargs.get("disable_update", False))
+ except:
+ pass
+
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+ logging.basicConfig(level=log_level)
+
+ model, kwargs = self.build_model(**kwargs)
+
+ # if vad_model is not None, build vad model else None
+ vad_model = kwargs.get("vad_model", None)
+ vad_kwargs = (
+ {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
+ )
+ if vad_model is not None:
+ logging.info("Building VAD model.")
+ vad_kwargs["model"] = vad_model
+ vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
+ vad_kwargs["device"] = kwargs["device"]
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+ # if punc_model is not None, build punc model else None
+ punc_model = kwargs.get("punc_model", None)
+ punc_kwargs = (
+ {}
+ if kwargs.get("punc_kwargs", {}) is None
+ else kwargs.get("punc_kwargs", {})
+ )
+ if punc_model is not None:
+ logging.info("Building punc model.")
+ punc_kwargs["model"] = punc_model
+ punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
+ punc_kwargs["device"] = kwargs["device"]
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+ # if spk_model is not None, build spk model else None
+ spk_model = kwargs.get("spk_model", None)
+ spk_kwargs = (
+ {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
+ )
+ if spk_model is not None:
+ logging.info("Building SPK model.")
+ spk_kwargs["model"] = spk_model
+ spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
+ spk_kwargs["device"] = kwargs["device"]
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+ self.cb_model = ClusterBackend().to(kwargs["device"])
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+ logging.error(
+ "spk_mode should be one of default, vad_segment and punc_segment."
+ )
+ self.spk_mode = spk_mode
+
+ self.kwargs = kwargs
+ self.model = model
+ self.vad_model = vad_model
+ self.vad_kwargs = vad_kwargs
+ self.punc_model = punc_model
+ self.punc_kwargs = punc_kwargs
+ self.spk_model = spk_model
+ self.spk_kwargs = spk_kwargs
+ self.model_path = kwargs.get("model_path")
+
+ @staticmethod
+ def build_model(**kwargs):
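+        # Download the model from the model hub when no local model_conf is provided, build the
+        # tokenizer and frontend, instantiate the registered model class, and load pretrained
+        # weights before moving the model to the requested device and dtype.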
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info(
+ "download models from model hub: {}".format(kwargs.get("hub", "ms"))
+ )
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ torch.set_num_threads(kwargs.get("ncpu", 4))
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+ tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
+ kwargs["token_list"] = (
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ )
+ kwargs["token_list"] = (
+ tokenizer.get_vocab()
+ if hasattr(tokenizer, "get_vocab")
+ else kwargs["token_list"]
+ )
+ vocab_size = (
+ len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+ )
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
+ else:
+ vocab_size = -1
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend
+ frontend = kwargs.get("frontend", None)
+ kwargs["input_size"] = None
+ if frontend is not None:
+ frontend_class = tables.frontend_classes.get(frontend)
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
+ kwargs["input_size"] = (
+ frontend.output_size() if hasattr(frontend, "output_size") else None
+ )
+ kwargs["frontend"] = frontend
+ # build model
+ model_class = tables.model_classes.get(kwargs["model"])
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
+ model_conf = {}
+ deep_update(model_conf, kwargs.get("model_conf", {}))
+ deep_update(model_conf, kwargs)
+ model = model_class(**model_conf, vocab_size=vocab_size)
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ if os.path.exists(init_param):
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ path=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ scope_map=kwargs.get("scope_map", []),
+ excludes=kwargs.get("excludes", None),
+ )
+ else:
+ print(f"error, init_param does not exist!: {init_param}")
+
+ # fp16
+ if kwargs.get("fp16", False):
+ model.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ model.to(torch.bfloat16)
+ model.to(device)
+
+ if not kwargs.get("disable_log", True):
+ tables.print()
+
+ return model, kwargs
+
+ def __call__(self, *args, **cfg):
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ res = self.model(*args, kwargs)
+ return res
+
+ def generate(self, input, input_len=None, **cfg):
+ if self.vad_model is None:
+ return self.inference(input, input_len=input_len, **cfg)
+
+ else:
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+ def inference(
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+ ):
+ kwargs = self.kwargs if kwargs is None else kwargs
+ if "cache" in kwargs:
+ kwargs.pop("cache")
+ deep_update(kwargs, cfg)
+ model = self.model if model is None else model
+ model.eval()
+
+ batch_size = kwargs.get("batch_size", 1)
+ # if kwargs.get("device", "cpu") == "cpu":
+ # batch_size = 1
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+ )
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ disable_pbar = self.kwargs.get("disable_pbar", False)
+ pbar = (
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ if not disable_pbar
+ else None
+ )
+ time_speech_total = 0.0
+ time_escape_total = 0.0
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+
+ if (end_idx - beg_idx) == 1 and kwargs.get(
+ "data_type", None
+ ) == "fbank": # fbank
+ batch["data_in"] = data_batch[0]
+ batch["data_lengths"] = input_len
+
+ time1 = time.perf_counter()
+ with torch.no_grad():
+ res = model.inference(**batch, **kwargs)
+ if isinstance(res, (list, tuple)):
+ results = res[0] if len(res) > 0 else [{"text": ""}]
+ meta_data = res[1] if len(res) > 1 else {}
+ time2 = time.perf_counter()
+
+ asr_result_list.extend(results)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ time_escape = time2 - time1
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+ speed_stats["forward"] = f"{time_escape:0.3f}"
+ speed_stats["batch_size"] = f"{len(results)}"
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+ description = f"{speed_stats}, "
+ if pbar:
+ pbar.update(end_idx - beg_idx)
+ pbar.set_description(description)
+ time_speech_total += batch_data_time
+ time_escape_total += time_escape
+
+ if pbar:
+ # pbar.update(1)
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+ torch.cuda.empty_cache()
+ return asr_result_list
+
+ def vad(self, input, input_len=None, **cfg):
+ kwargs = self.kwargs
+ # step.1: compute the vad model
+ deep_update(self.vad_kwargs, cfg)
+ beg_vad = time.time()
+ res = self.inference(
+ input,
+ input_len=input_len,
+ model=self.vad_model,
+ kwargs=self.vad_kwargs,
+ **cfg,
+ )
+ end_vad = time.time()
+        # FIX(gcf): concatenate the VAD clips for the SenseVoice model for better AED
+ if cfg.get("merge_vad", False):
+ for i in range(len(res)):
+ res[i]["value"] = merge_vad(
+ res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
+ )
+ elapsed = end_vad - beg_vad
+ return elapsed, res
+
+ def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
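+        # Decode each VAD segment with the ASR model: segments are sorted by duration and
+        # greedily packed into batches bounded by batch_size_s / batch_size_threshold_s,
+        # optional speaker embeddings and punctuation are attached per segment, and the
+        # results are finally restored to the original segment order.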
+
+ kwargs = self.kwargs
+
+ # step.2 compute asr model
+ model = self.model
+ deep_update(kwargs, cfg)
+ batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+ kwargs["batch_size"] = batch_size
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
+ )
+ results_ret_list = []
+ time_speech_total_all_samples = 1e-6
+
+ beg_total = time.time()
+ pbar_total = (
+ tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
+ if not kwargs.get("disable_pbar", False)
+ else None
+ )
+
+ for i in range(len(vad_res)):
+ key = vad_res[i]["key"]
+ vadsegments = vad_res[i]["value"]
+ input_i = data_list[i]
+ fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
+ speech = load_audio_text_image_video(
+ input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
+ )
+ speech_lengths = len(speech)
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+
+ if not len(sorted_data):
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
+ logging.info("decoding, utt: {}, empty speech".format(key))
+ continue
+
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+ batch_size = max(
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+ )
+
+ if kwargs["device"] == "cpu":
+ batch_size = 0
+
+ beg_idx = 0
+ beg_asr_total = time.time()
+ time_speech_total_per_sample = speech_lengths / 16000
+ time_speech_total_all_samples += time_speech_total_per_sample
+
+ # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
+
+ all_segments = []
+ max_len_in_batch = 0
+ end_idx = 1
+
+ for j, _ in enumerate(range(0, n)):
+ # pbar_sample.update(1)
+ sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+ potential_batch_length = max(max_len_in_batch, sample_length) * (
+ j + 1 - beg_idx
+ )
+ # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+ if (
+ j < n - 1
+ and sample_length < batch_size_threshold_ms
+ and potential_batch_length < batch_size
+ ):
+ max_len_in_batch = max(max_len_in_batch, sample_length)
+ end_idx += 1
+ continue
+
+ speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
+ )
+ results = self.inference(
+ speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
+ )
+
+ for _b in range(len(speech_j)):
+ results[_b]["interval"] = intervals[_b]
+
+ if self.spk_model is not None:
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+ for _b in range(len(speech_j)):
+ vad_segments = [
+ [
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
+ np.array(speech_j[_b]),
+ ]
+ ]
+ segments = sv_chunk(vad_segments)
+ all_segments.extend(segments)
+ speech_b = [i[2] for i in segments]
+ spk_res = self.inference(
+ speech_b,
+ input_len=None,
+ model=self.spk_model,
+ kwargs=kwargs,
+ **cfg,
+ )
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
+
+ beg_idx = end_idx
+ end_idx += 1
+ max_len_in_batch = sample_length
+ if len(results) < 1:
+ continue
+ results_sorted.extend(results)
+
+ # end_asr_total = time.time()
+ # time_escape_total_per_sample = end_asr_total - beg_asr_total
+ # pbar_sample.update(1)
+ # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+ # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ cur = results_sorted[j]
+ pattern = r"<\|([^|]+)\|>"
+ emotion_string = re.findall(pattern, cur["text"])
+ cur["text"] = re.sub(pattern, "", cur["text"])
+ cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
+ if self.punc_model is not None and len(cur["text"].strip()) > 0:
+ deep_update(self.punc_kwargs, cfg)
+ punc_res = self.inference(
+ cur["text"],
+ model=self.punc_model,
+ kwargs=self.punc_kwargs,
+ **cfg,
+ )
+ cur["text"] = punc_res[0]["text"]
+
+ restored_data[index] = cur
+
+ end_asr_total = time.time()
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
+ if pbar_total:
+ pbar_total.update(1)
+ pbar_total.set_description(
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
+ )
+
+ # end_total = time.time()
+ # time_escape_total_all_samples = end_total - beg_total
+ # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+ # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
+ # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
+ return restored_data
+
+ def export(self, input=None, **cfg):
+ """
+
+ :param input:
+ :param type:
+ :param quantize:
+ :param fallback_num:
+ :param calib_num:
+ :param opset_version:
+ :param cfg:
+ :return:
+ """
+
+ device = cfg.get("device", "cpu")
+ model = self.model.to(device=device)
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ kwargs["device"] = device
+ del kwargs["model"]
+ model.eval()
+
+ type = kwargs.get("type", "onnx")
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=None, data_type=kwargs.get("data_type", None), key=None
+ )
+
+ with torch.no_grad():
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
+
+ return export_dir
diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py
index 6789316d5186db69c021758094649553c3638f66..dab2cba6df708eacdb5b7e453ef306bdb89619cf 100644
--- a/tools/sensevoice/fun_asr.py
+++ b/tools/sensevoice/fun_asr.py
@@ -1,332 +1,332 @@
-import gc
-import os
-import re
-
-from audio_separator.separator import Separator
-
-os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
-os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
-import json
-import subprocess
-from pathlib import Path
-
-import click
-import torch
-from loguru import logger
-from pydub import AudioSegment
-from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
-from tools.sensevoice.auto_model import AutoModel
-
-
-def uvr5_cli(
- audio_dir: Path,
- output_folder: Path,
- audio_files: list[Path] | None = None,
- output_format: str = "flac",
- model: str = "BS-Roformer-Viperx-1297.ckpt",
-):
- # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
- sepr = Separator(
- model_file_dir=os.environ["UVR5_CACHE"],
- output_dir=output_folder,
- output_format=output_format,
- )
- dictmodel = {
- "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
- "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
- "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
- "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
- }
- roformer_model = dictmodel[model]
- sepr.load_model(roformer_model)
- if audio_files is None:
- audio_files = list_files(
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
- )
- total_files = len(audio_files)
-
- print(f"{total_files} audio files found")
-
- res = []
- for audio in tqdm(audio_files, desc="Denoising: "):
- file_path = str(audio_dir / audio)
- sep_out = sepr.separate(file_path)
- if isinstance(sep_out, str):
- res.append(sep_out)
- elif isinstance(sep_out, list):
- res.extend(sep_out)
- del sepr
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- return res, roformer_model
-
-
-def get_sample_rate(media_path: Path):
- result = subprocess.run(
- [
- "ffprobe",
- "-v",
- "quiet",
- "-print_format",
- "json",
- "-show_streams",
- str(media_path),
- ],
- capture_output=True,
- text=True,
- check=True,
- )
- media_info = json.loads(result.stdout)
- for stream in media_info.get("streams", []):
- if stream.get("codec_type") == "audio":
- return stream.get("sample_rate")
- return "44100" # Default sample rate if not found
-
-
-def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
- sr = get_sample_rate(src_path)
- out_path.parent.mkdir(parents=True, exist_ok=True)
- if src_path.resolve() == out_path.resolve():
- output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
- else:
- output = str(out_path)
- subprocess.run(
- [
- "ffmpeg",
- "-loglevel",
- "error",
- "-i",
- str(src_path),
- "-acodec",
- "pcm_s16le" if out_fmt == "wav" else "flac",
- "-ar",
- sr,
- "-ac",
- "1",
- "-y",
- output,
- ],
- check=True,
- )
- return out_path
-
-
-def convert_video_to_audio(video_path: Path, audio_dir: Path):
- cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
- vocals = [
- p
- for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
- if p.suffix in AUDIO_EXTENSIONS
- ]
- if len(vocals) > 0:
- return vocals[0]
- audio_path = cur_dir / f"{video_path.stem}.wav"
- convert_to_mono(video_path, audio_path)
- return audio_path
-
-
-@click.command()
-@click.option("--audio-dir", required=True, help="Directory containing audio files")
-@click.option(
- "--save-dir", required=True, help="Directory to save processed audio files"
-)
-@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
-@click.option("--language", default="auto", help="Language of the transcription")
-@click.option(
- "--max_single_segment_time",
- default=20000,
- type=int,
- help="Maximum of Output single audio duration(ms)",
-)
-@click.option("--fsmn-vad/--silero-vad", default=False)
-@click.option("--punc/--no-punc", default=False)
-@click.option("--denoise/--no-denoise", default=False)
-@click.option("--save_emo/--no_save_emo", default=False)
-def main(
- audio_dir: str,
- save_dir: str,
- device: str,
- language: str,
- max_single_segment_time: int,
- fsmn_vad: bool,
- punc: bool,
- denoise: bool,
- save_emo: bool,
-):
-
- audios_path = Path(audio_dir)
- save_path = Path(save_dir)
- save_path.mkdir(parents=True, exist_ok=True)
-
- video_files = list_files(
- path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
- )
- v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
-
- if denoise:
- VOCAL = "_(Vocals)"
- original_files = [
- p
- for p in audios_path.glob("**/*")
- if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
- ]
-
- _, cur_model = uvr5_cli(
- audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
- )
- need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
- need_remove.extend(original_files)
- for _ in need_remove:
- _.unlink()
- vocal_files = [
- p
- for p in audios_path.glob("**/*")
- if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
- ]
- for f in vocal_files:
- fn, ext = f.stem, f.suffix
-
- v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
- if v_pos != -1:
- new_fn = fn[: v_pos + len(VOCAL)]
- new_f = f.with_name(new_fn + ext)
- f = f.rename(new_f)
- convert_to_mono(f, f, "flac")
- f.unlink()
-
- audio_files = list_files(
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
- )
-
- logger.info("Loading / Downloading Funasr model...")
-
- model_dir = "iic/SenseVoiceSmall"
-
- vad_model = "fsmn-vad" if fsmn_vad else None
- vad_kwargs = {"max_single_segment_time": max_single_segment_time}
- punc_model = "ct-punc" if punc else None
-
- manager = AutoModel(
- model=model_dir,
- trust_remote_code=False,
- vad_model=vad_model,
- vad_kwargs=vad_kwargs,
- punc_model=punc_model,
- device=device,
- )
-
- if not fsmn_vad and vad_model is None:
- vad_model = load_silero_vad()
-
- logger.info("Model loaded.")
-
- pattern = re.compile(r"_\d{3}\.")
-
- for file_path in tqdm(audio_files, desc="Processing audio file"):
-
- if pattern.search(file_path.name):
- # logger.info(f"Skipping {file_path} as it has already been processed.")
- continue
-
- file_stem = file_path.stem
- file_suffix = file_path.suffix
-
- rel_path = Path(file_path).relative_to(audio_dir)
- (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-
- audio = AudioSegment.from_file(file_path)
-
- cfg = dict(
- cache={},
- language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
- use_itn=False,
- batch_size_s=60,
- )
-
- if fsmn_vad:
- elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
- else:
- wav = read_audio(
- str(file_path)
- ) # backend (sox, soundfile, or ffmpeg) required!
- audio_key = file_path.stem
- audio_val = []
- speech_timestamps = get_speech_timestamps(
- wav,
- vad_model,
- max_speech_duration_s=max_single_segment_time // 1000,
- return_seconds=True,
- )
-
- audio_val = [
- [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
- for timestamp in speech_timestamps
- ]
- vad_res = []
- vad_res.append(dict(key=audio_key, value=audio_val))
-
- res = manager.inference_with_vadres(
- input=str(file_path), vad_res=vad_res, **cfg
- )
-
- for i, info in enumerate(res):
- [start_ms, end_ms] = info["interval"]
- text = info["text"]
- emo = info["emo"]
- sliced_audio = audio[start_ms:end_ms]
- audio_save_path = (
- save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
- )
- sliced_audio.export(audio_save_path, format=file_suffix[1:])
- print(f"Exported {audio_save_path}: {text}")
-
- transcript_save_path = (
- save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
- )
- with open(
- transcript_save_path,
- "w",
- encoding="utf-8",
- ) as f:
- f.write(text)
-
- if save_emo:
- emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
- with open(
- emo_save_path,
- "w",
- encoding="utf-8",
- ) as f:
- f.write(emo)
-
- if audios_path.resolve() == save_path.resolve():
- file_path.unlink()
-
-
-if __name__ == "__main__":
- main()
- exit(0)
- from funasr.utils.postprocess_utils import rich_transcription_postprocess
-
- # Load the audio file
- audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
- model_dir = "iic/SenseVoiceSmall"
- m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
- m.eval()
-
- res = m.inference(
- data_in=f"{kwargs['model_path']}/example/zh.mp3",
- language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
- use_itn=False,
- ban_emo_unk=False,
- **kwargs,
- )
-
- print(res)
- text = rich_transcription_postprocess(res[0][0]["text"])
- print(text)
+import gc
+import os
+import re
+
+from audio_separator.separator import Separator
+
+os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
+os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
+import json
+import subprocess
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+from pydub import AudioSegment
+from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from tools.sensevoice.auto_model import AutoModel
+
+
+def uvr5_cli(
+ audio_dir: Path,
+ output_folder: Path,
+ audio_files: list[Path] | None = None,
+ output_format: str = "flac",
+ model: str = "BS-Roformer-Viperx-1297.ckpt",
+):
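+    # Run vocal separation with audio-separator's Roformer checkpoints and return the
+    # separated output file paths together with the checkpoint name that was used.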
+ # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
+ sepr = Separator(
+ model_file_dir=os.environ["UVR5_CACHE"],
+ output_dir=output_folder,
+ output_format=output_format,
+ )
+ dictmodel = {
+ "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
+ "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
+ "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
+ "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
+ }
+ roformer_model = dictmodel[model]
+ sepr.load_model(roformer_model)
+ if audio_files is None:
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+ total_files = len(audio_files)
+
+ print(f"{total_files} audio files found")
+
+ res = []
+ for audio in tqdm(audio_files, desc="Denoising: "):
+ file_path = str(audio_dir / audio)
+ sep_out = sepr.separate(file_path)
+ if isinstance(sep_out, str):
+ res.append(sep_out)
+ elif isinstance(sep_out, list):
+ res.extend(sep_out)
+ del sepr
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return res, roformer_model
+
+
+def get_sample_rate(media_path: Path):
+ result = subprocess.run(
+ [
+ "ffprobe",
+ "-v",
+ "quiet",
+ "-print_format",
+ "json",
+ "-show_streams",
+ str(media_path),
+ ],
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+ media_info = json.loads(result.stdout)
+ for stream in media_info.get("streams", []):
+ if stream.get("codec_type") == "audio":
+ return stream.get("sample_rate")
+ return "44100" # Default sample rate if not found
+
+
+def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
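+    # Re-encode the source media as mono audio with ffmpeg, keeping the original sample rate;
+    # when converting a file onto itself, the written file gains a "_<samplerate>" suffix.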
+ sr = get_sample_rate(src_path)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ if src_path.resolve() == out_path.resolve():
+ output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
+ else:
+ output = str(out_path)
+ subprocess.run(
+ [
+ "ffmpeg",
+ "-loglevel",
+ "error",
+ "-i",
+ str(src_path),
+ "-acodec",
+ "pcm_s16le" if out_fmt == "wav" else "flac",
+ "-ar",
+ sr,
+ "-ac",
+ "1",
+ "-y",
+ output,
+ ],
+ check=True,
+ )
+ return out_path
+
+
+def convert_video_to_audio(video_path: Path, audio_dir: Path):
+ cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
+ vocals = [
+ p
+ for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
+ if p.suffix in AUDIO_EXTENSIONS
+ ]
+ if len(vocals) > 0:
+ return vocals[0]
+ audio_path = cur_dir / f"{video_path.stem}.wav"
+ convert_to_mono(video_path, audio_path)
+ return audio_path
+
+
+@click.command()
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option(
+ "--max_single_segment_time",
+ default=20000,
+ type=int,
+ help="Maximum of Output single audio duration(ms)",
+)
+@click.option("--fsmn-vad/--silero-vad", default=False)
+@click.option("--punc/--no-punc", default=False)
+@click.option("--denoise/--no-denoise", default=False)
+@click.option("--save_emo/--no_save_emo", default=False)
+def main(
+ audio_dir: str,
+ save_dir: str,
+ device: str,
+ language: str,
+ max_single_segment_time: int,
+ fsmn_vad: bool,
+ punc: bool,
+ denoise: bool,
+ save_emo: bool,
+):
+
+ audios_path = Path(audio_dir)
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ video_files = list_files(
+ path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
+ )
+ v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
+
+ if denoise:
+ VOCAL = "_(Vocals)"
+ original_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
+ ]
+
+ _, cur_model = uvr5_cli(
+ audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
+ )
+ need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
+ need_remove.extend(original_files)
+ for _ in need_remove:
+ _.unlink()
+ vocal_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
+ ]
+ for f in vocal_files:
+ fn, ext = f.stem, f.suffix
+
+ v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
+ if v_pos != -1:
+ new_fn = fn[: v_pos + len(VOCAL)]
+ new_f = f.with_name(new_fn + ext)
+ f = f.rename(new_f)
+ convert_to_mono(f, f, "flac")
+ f.unlink()
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ logger.info("Loading / Downloading Funasr model...")
+
+ model_dir = "iic/SenseVoiceSmall"
+
+ vad_model = "fsmn-vad" if fsmn_vad else None
+ vad_kwargs = {"max_single_segment_time": max_single_segment_time}
+ punc_model = "ct-punc" if punc else None
+
+ manager = AutoModel(
+ model=model_dir,
+ trust_remote_code=False,
+ vad_model=vad_model,
+ vad_kwargs=vad_kwargs,
+ punc_model=punc_model,
+ device=device,
+ )
+
+ if not fsmn_vad and vad_model is None:
+ vad_model = load_silero_vad()
+
+ logger.info("Model loaded.")
+
+ pattern = re.compile(r"_\d{3}\.")
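+    # Files whose names already contain a "_NNN." index are previously exported slices and are skipped.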
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+
+ if pattern.search(file_path.name):
+ # logger.info(f"Skipping {file_path} as it has already been processed.")
+ continue
+
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
+ cfg = dict(
+ cache={},
+ language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ batch_size_s=60,
+ )
+
+ if fsmn_vad:
+ elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
+ else:
+ wav = read_audio(
+ str(file_path)
+ ) # backend (sox, soundfile, or ffmpeg) required!
+ audio_key = file_path.stem
+ audio_val = []
+ speech_timestamps = get_speech_timestamps(
+ wav,
+ vad_model,
+ max_speech_duration_s=max_single_segment_time // 1000,
+ return_seconds=True,
+ )
+
+ audio_val = [
+ [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
+ for timestamp in speech_timestamps
+ ]
+ vad_res = []
+ vad_res.append(dict(key=audio_key, value=audio_val))
+
+ res = manager.inference_with_vadres(
+ input=str(file_path), vad_res=vad_res, **cfg
+ )
+
+ for i, info in enumerate(res):
+ [start_ms, end_ms] = info["interval"]
+ text = info["text"]
+ emo = info["emo"]
+ sliced_audio = audio[start_ms:end_ms]
+ audio_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
+ )
+ sliced_audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}: {text}")
+
+ transcript_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
+ )
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(text)
+
+ if save_emo:
+ emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
+ with open(
+ emo_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(emo)
+
+ if audios_path.resolve() == save_path.resolve():
+ file_path.unlink()
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
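+    # The code below is unreachable scratch/demo usage; note that `SenseVoiceSmall` is not
+    # imported anywhere in this script and would need to be imported for it to run.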
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+ # Load the audio file
+ audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
+ model_dir = "iic/SenseVoiceSmall"
+ m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
+ m.eval()
+
+ res = m.inference(
+ data_in=f"{kwargs['model_path']}/example/zh.mp3",
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ ban_emo_unk=False,
+ **kwargs,
+ )
+
+ print(res)
+ text = rich_transcription_postprocess(res[0][0]["text"])
+ print(text)
diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py
index 3bef75ed8c2841701fff44f7130e91ef8dfdf8cc..d1fbe4974da382ef29055462f1ef2443b5980fe7 100644
--- a/tools/sensevoice/vad_utils.py
+++ b/tools/sensevoice/vad_utils.py
@@ -1,61 +1,61 @@
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-
-def slice_padding_fbank(speech, speech_lengths, vad_segments):
- speech_list = []
- speech_lengths_list = []
- for i, segment in enumerate(vad_segments):
-
- bed_idx = int(segment[0][0] * 16)
- end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
- speech_i = speech[0, bed_idx:end_idx]
- speech_lengths_i = end_idx - bed_idx
- speech_list.append(speech_i)
- speech_lengths_list.append(speech_lengths_i)
- feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
- speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
- return feats_pad, speech_lengths_pad
-
-
-def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
- speech_list = []
- speech_lengths_list = []
- intervals = []
- for i, segment in enumerate(vad_segments):
- bed_idx = int(segment[0][0] * 16)
- end_idx = min(int(segment[0][1] * 16), speech_lengths)
- speech_i = speech[bed_idx:end_idx]
- speech_lengths_i = end_idx - bed_idx
- speech_list.append(speech_i)
- speech_lengths_list.append(speech_lengths_i)
- intervals.append([bed_idx // 16, end_idx // 16])
-
- return speech_list, speech_lengths_list, intervals
-
-
-def merge_vad(vad_result, max_length=15000, min_length=0):
- new_result = []
- if len(vad_result) <= 1:
- return vad_result
- time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
- time_step = sorted(list(set(time_step)))
- if len(time_step) == 0:
- return []
- bg = 0
- for i in range(len(time_step) - 1):
- time = time_step[i]
- if time_step[i + 1] - bg < max_length:
- continue
- if time - bg > min_length:
- new_result.append([bg, time])
- # if time - bg < max_length * 1.5:
- # new_result.append([bg, time])
- # else:
- # split_num = int(time - bg) // max_length + 1
- # spl_l = int(time - bg) // split_num
- # for j in range(split_num):
- # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
- bg = time
- new_result.append([bg, time_step[-1]])
- return new_result
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
+ speech_i = speech[0, bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+ return feats_pad, speech_lengths_pad
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ intervals = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ intervals.append([bed_idx // 16, end_idx // 16])
+
+ return speech_list, speech_lengths_list, intervals
+
+
+def merge_vad(vad_result, max_length=15000, min_length=0):
+ new_result = []
+ if len(vad_result) <= 1:
+ return vad_result
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
+ time_step = sorted(list(set(time_step)))
+ if len(time_step) == 0:
+ return []
+ bg = 0
+ for i in range(len(time_step) - 1):
+ time = time_step[i]
+ if time_step[i + 1] - bg < max_length:
+ continue
+ if time - bg > min_length:
+ new_result.append([bg, time])
+ # if time - bg < max_length * 1.5:
+ # new_result.append([bg, time])
+ # else:
+ # split_num = int(time - bg) // max_length + 1
+ # spl_l = int(time - bg) // split_num
+ # for j in range(split_num):
+ # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
+ bg = time
+ new_result.append([bg, time_step[-1]])
+ return new_result
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
index 9772168f5136806a1fac4b8ab5cfeefada875b2d..6ce8c4d8dd0fd63e8039822adb4424a38d8e80fe 100644
--- a/tools/smart_pad.py
+++ b/tools/smart_pad.py
@@ -1,47 +1,60 @@
-import random
-from multiprocessing import Pool
-from pathlib import Path
-
-import click
-import librosa
-import torch.nn.functional as F
-import torchaudio
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files
-
-threshold = 10 ** (-50 / 20.0)
-
-
-def process(file):
- waveform, sample_rate = torchaudio.load(str(file), backend="sox")
- loudness = librosa.feature.rms(
- y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
- )[0]
- for i in range(len(loudness) - 1, 0, -1):
- if loudness[i] > threshold:
- break
-
- silent_time = (len(loudness) - i) * 512 / sample_rate
-
- if silent_time <= 0.3:
- random_time = random.uniform(0.3, 0.7)
- waveform = F.pad(
- waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
- )
-
- torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
-
-
-@click.command()
-@click.argument("source", type=Path)
-@click.option("--num-workers", type=int, default=12)
-def main(source, num_workers):
- files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
-
- with Pool(num_workers) as p:
- list(tqdm(p.imap_unordered(process, files), total=len(files)))
-
-
-if __name__ == "__main__":
- main()
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+threshold = 10 ** (-50 / 20.0)
+
+
+def process(file):
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
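+    # Downmix multi-channel audio to mono before measuring loudness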
+ if waveform.size(0) > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+
+ loudness = librosa.feature.rms(
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
+ )[0]
+
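+    # Scan backwards to find the last frame above the silence threshold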
+ for i in range(len(loudness) - 1, 0, -1):
+ if loudness[i] > threshold:
+ break
+
+ end_silent_time = (len(loudness) - i) * 512 / sample_rate
+
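+    # Pad the tail so the clip ends with roughly 0.3-0.7 seconds of silence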
+ if end_silent_time <= 0.3:
+ random_time = random.uniform(0.3, 0.7) - end_silent_time
+ waveform = F.pad(
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
+ )
+
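+    # Scan forwards to measure the leading silence, then trim it down to about 20 ms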
+ for i in range(len(loudness)):
+ if loudness[i] > threshold:
+ break
+
+ start_silent_time = i * 512 / sample_rate
+
+ if start_silent_time > 0.02:
+ waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
+
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
+@click.command()
+@click.argument("source", type=Path)
+@click.option("--num-workers", type=int, default=12)
+def main(source, num_workers):
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
+
+ with Pool(num_workers) as p:
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
index d24a5f39566c47ea0cb1fc506d463e9c95c3efbc..bdb914c2a3b13d13a79882be3e7b33027ce5109a 100644
--- a/tools/vqgan/create_train_split.py
+++ b/tools/vqgan/create_train_split.py
@@ -1,83 +1,83 @@
-import math
-from pathlib import Path
-from random import Random
-
-import click
-from loguru import logger
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-
-@click.command()
-@click.argument("root", type=click.Path(exists=True, path_type=Path))
-@click.option("--val-ratio", type=float, default=None)
-@click.option("--val-count", type=int, default=None)
-@click.option("--filelist", default=None, type=Path)
-@click.option("--min-duration", default=None, type=float)
-@click.option("--max-duration", default=None, type=float)
-def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
- if filelist:
- files = [i[0] for i in load_filelist(filelist)]
- else:
- files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
-
- if min_duration is None and max_duration is None:
- filtered_files = list(map(str, [file.relative_to(root) for file in files]))
- else:
- filtered_files = []
- for file in tqdm(files):
- try:
- audio = AudioSegment.from_file(str(file))
- duration = len(audio) / 1000.0
-
- if min_duration is not None and duration < min_duration:
- logger.info(
- f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
- )
- continue
-
- if max_duration is not None and duration > max_duration:
- logger.info(
- f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
- )
- continue
-
- filtered_files.append(str(file.relative_to(root)))
- except Exception as e:
- logger.info(f"Error processing {file}: {e}")
-
- logger.info(
- f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
- )
-
- Random(42).shuffle(filtered_files)
-
- if val_count is None and val_ratio is None:
- logger.info("Validation ratio and count not specified, using min(20%, 100)")
- val_size = min(100, math.ceil(len(filtered_files) * 0.2))
- elif val_count is not None and val_ratio is not None:
- logger.error("Cannot specify both val_count and val_ratio")
- return
- elif val_count is not None:
- if val_count < 1 or val_count > len(filtered_files):
- logger.error("val_count must be between 1 and number of files")
- return
- val_size = val_count
- else:
- val_size = math.ceil(len(filtered_files) * val_ratio)
-
- logger.info(f"Using {val_size} files for validation")
-
- with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
- f.write("\n".join(filtered_files[val_size:]))
-
- with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
- f.write("\n".join(filtered_files[:val_size]))
-
- logger.info("Done")
-
-
-if __name__ == "__main__":
- main()
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+@click.option("--val-ratio", type=float, default=None)
+@click.option("--val-count", type=int, default=None)
+@click.option("--filelist", default=None, type=Path)
+@click.option("--min-duration", default=None, type=float)
+@click.option("--max-duration", default=None, type=float)
+def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+ if min_duration is None and max_duration is None:
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
+ else:
+ filtered_files = []
+ for file in tqdm(files):
+ try:
+ audio = AudioSegment.from_file(str(file))
+ duration = len(audio) / 1000.0
+
+ if min_duration is not None and duration < min_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
+ )
+ continue
+
+ if max_duration is not None and duration > max_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
+ )
+ continue
+
+ filtered_files.append(str(file.relative_to(root)))
+ except Exception as e:
+ logger.info(f"Error processing {file}: {e}")
+
+ logger.info(
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
+ )
+
+ Random(42).shuffle(filtered_files)
+
+ if val_count is None and val_ratio is None:
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
+ elif val_count is not None and val_ratio is not None:
+ logger.error("Cannot specify both val_count and val_ratio")
+ return
+ elif val_count is not None:
+ if val_count < 1 or val_count > len(filtered_files):
+ logger.error("val_count must be between 1 and number of files")
+ return
+ val_size = val_count
+ else:
+ val_size = math.ceil(len(filtered_files) * val_ratio)
+
+ logger.info(f"Using {val_size} files for validation")
+
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[val_size:]))
+
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[:val_size]))
+
+ logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
index c24eb3f46ab57fb02930f233a67299cb31c7d7ba..bccc721a7f5d5cb68596df62d7b4c629814538c5 100644
--- a/tools/vqgan/extract_vq.py
+++ b/tools/vqgan/extract_vq.py
@@ -1,227 +1,233 @@
-import os
-import subprocess as sp
-import sys
-import time
-from datetime import timedelta
-from functools import lru_cache
-from pathlib import Path
-from random import Random
-
-import click
-import numpy as np
-import torch
-import torchaudio
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from lightning import LightningModule
-from loguru import logger
-from omegaconf import OmegaConf
-
-from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-# This file is used to convert the audio files to text files using the Whisper model.
-# It's mainly used to generate the training data for the VQ model.
-
-
-RANK = int(os.environ.get("SLURM_PROCID", 0))
-WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
-
-logger_format = (
- "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
- "{level: <8} | "
- "{name}:{function}:{line} | "
- "{extra[rank]} - {message}"
-)
-logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
-logger.remove()
-logger.add(sys.stderr, format=logger_format)
-
-
-@lru_cache(maxsize=1)
-def get_model(
- config_name: str = "firefly_gan_vq",
- checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- device: str | torch.device = "cuda",
-):
- with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
- cfg = compose(config_name=config_name)
-
- model = instantiate(cfg)
- state_dict = torch.load(
- checkpoint_path,
- map_location=device,
- )
- if "state_dict" in state_dict:
- state_dict = state_dict["state_dict"]
-
- if any("generator" in k for k in state_dict):
- state_dict = {
- k.replace("generator.", ""): v
- for k, v in state_dict.items()
- if "generator." in k
- }
-
- model.load_state_dict(state_dict, strict=False)
- model.eval()
- model.to(device)
-
- logger.info(f"Loaded model")
- return model
-
-
-@torch.inference_mode()
-def process_batch(files: list[Path], model) -> float:
- wavs = []
- audio_lengths = []
- new_files = []
- max_length = total_time = 0
-
- for file in files:
- try:
- wav, sr = torchaudio.load(
- str(file), backend="sox" if sys.platform == "linux" else "soundfile"
- ) # Need to install libsox-dev
- except Exception as e:
- logger.error(f"Error reading {file}: {e}")
- continue
-
- if wav.shape[0] > 1:
- wav = wav.mean(dim=0, keepdim=True)
-
- wav = torchaudio.functional.resample(
- wav.cuda(), sr, model.spec_transform.sample_rate
- )[0]
- total_time += len(wav) / model.spec_transform.sample_rate
- max_length = max(max_length, len(wav))
-
- wavs.append(wav)
- audio_lengths.append(len(wav))
- new_files.append(file)
-
- files = new_files
-
- # Pad to max length
- for i, wav in enumerate(wavs):
- wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
-
- audios = torch.stack(wavs, dim=0)[:, None]
- audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
-
- # Calculate lengths
- indices, feature_lengths = model.encode(audios, audio_lengths)
-
- # Save to disk
- outputs = indices.cpu().numpy()
-
- for file, length, feature, audio_length in zip(
- files, feature_lengths, outputs, audio_lengths
- ):
- feature = feature[:, :length]
-
- # (T,)
- with open(file.with_suffix(".npy"), "wb") as f:
- np.save(f, feature)
-
- return total_time
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--num-workers", default=1)
-@click.option("--config-name", default="firefly_gan_vq")
-@click.option(
- "--checkpoint-path",
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-)
-@click.option("--batch-size", default=64)
-@click.option("--filelist", default=None, type=Path)
-def main(
- folder: str,
- num_workers: int,
- config_name: str,
- checkpoint_path: str,
- batch_size: int,
- filelist: Path,
-):
- if num_workers > 1 and WORLD_SIZE != num_workers:
- assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
-
- logger.info(f"Spawning {num_workers} workers")
-
- if torch.cuda.is_available():
- visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
- if visible_devices is None:
- visible_devices = list(range(torch.cuda.device_count()))
- else:
- visible_devices = visible_devices.split(",")
- else:
- # Set to empty string to avoid using GPU
- visible_devices = [""]
-
- processes = []
- for i in range(num_workers):
- env = os.environ.copy()
- env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
- env["SLURM_PROCID"] = str(i)
- env["SLURM_NTASKS"] = str(num_workers)
-
- processes.append(
- sp.Popen(
- [sys.executable] + sys.argv.copy(),
- env=env,
- )
- )
-
- for p in processes:
- p.wait()
-
- logger.info(f"All workers finished")
- return
-
- # This is a worker
- logger.info(f"Starting worker")
- if filelist:
- files = [i[0] for i in load_filelist(filelist)]
- else:
- files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
-
- print(f"Found {len(files)} files")
- files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
-
- total_files = len(files)
- files = files[RANK::WORLD_SIZE]
- logger.info(f"Processing {len(files)}/{total_files} files")
-
- # Batch processing
- total_time = 0
- begin_time = time.time()
- processed_files = 0
- model = get_model(config_name, checkpoint_path)
-
- for n_batch, idx in enumerate(range(0, len(files), batch_size)):
- batch = files[idx : idx + batch_size]
- batch_time = process_batch(batch, model)
-
- total_time += batch_time
- processed_files += len(batch)
-
- if (n_batch + 1) % 10 == 0:
- eta = (
- (time.time() - begin_time)
- / processed_files
- * (len(files) - processed_files)
- )
- logger.info(
- f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
- + f"ETA: {timedelta(seconds=round(eta))}s"
- )
-
- logger.info(
- f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
- )
-
-
-if __name__ == "__main__":
- main()
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file is used to extract VQ (codebook) indices from audio files.
+# It's mainly used to generate the training data for fine-tuning.
+
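+# Prefer the ffmpeg backend when torchaudio provides it; otherwise fall back to soundfile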
+backends = torchaudio.list_audio_backends()
+
+if "ffmpeg" in backends:
+ backend = "ffmpeg"
+else:
+ backend = "soundfile"
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+ "{level: <8} | "
+ "{name}:{function}:{line} | "
+ "{extra[rank]} - {message}"
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model(
+ config_name: str = "firefly_gan_vq",
+ checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ device: str | torch.device = "cuda",
+):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model")
+ return model
+
+
+@torch.inference_mode()
+def process_batch(files: list[Path], model) -> float:
+ wavs = []
+ audio_lengths = []
+ new_files = []
+ max_length = total_time = 0
+
+ for file in files:
+ try:
+ wav, sr = torchaudio.load(
+ str(file), backend=backend
+            )  # Backend chosen above: ffmpeg when available, otherwise soundfile
+ except Exception as e:
+ logger.error(f"Error reading {file}: {e}")
+ continue
+
+ if wav.shape[0] > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ wav = torchaudio.functional.resample(
+ wav.cuda(), sr, model.spec_transform.sample_rate
+ )[0]
+ total_time += len(wav) / model.spec_transform.sample_rate
+ max_length = max(max_length, len(wav))
+
+ wavs.append(wav)
+ audio_lengths.append(len(wav))
+ new_files.append(file)
+
+ files = new_files
+
+ # Pad to max length
+ for i, wav in enumerate(wavs):
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+ audios = torch.stack(wavs, dim=0)[:, None]
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+ # Calculate lengths
+ indices, feature_lengths = model.encode(audios, audio_lengths)
+
+ # Save to disk
+ outputs = indices.cpu().numpy()
+
+ for file, length, feature, audio_length in zip(
+ files, feature_lengths, outputs, audio_lengths
+ ):
+ feature = feature[:, :length]
+
+ # (T,)
+ with open(file.with_suffix(".npy"), "wb") as f:
+ np.save(f, feature)
+
+ return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
+def main(
+ folder: str,
+ num_workers: int,
+ config_name: str,
+ checkpoint_path: str,
+ batch_size: int,
+ filelist: Path,
+):
+ if num_workers > 1 and WORLD_SIZE != num_workers:
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+ logger.info(f"Spawning {num_workers} workers")
+
+ if torch.cuda.is_available():
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is None:
+ visible_devices = list(range(torch.cuda.device_count()))
+ else:
+ visible_devices = visible_devices.split(",")
+ else:
+ # Set to empty string to avoid using GPU
+ visible_devices = [""]
+
+ processes = []
+ for i in range(num_workers):
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+ env["SLURM_PROCID"] = str(i)
+ env["SLURM_NTASKS"] = str(num_workers)
+
+ processes.append(
+ sp.Popen(
+ [sys.executable] + sys.argv.copy(),
+ env=env,
+ )
+ )
+
+ for p in processes:
+ p.wait()
+
+ logger.info(f"All workers finished")
+ return
+
+ # This is a worker
+ logger.info(f"Starting worker")
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+ print(f"Found {len(files)} files")
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+ total_files = len(files)
+ files = files[RANK::WORLD_SIZE]
+ logger.info(f"Processing {len(files)}/{total_files} files")
+
+ # Batch processing
+ total_time = 0
+ begin_time = time.time()
+ processed_files = 0
+ model = get_model(config_name, checkpoint_path)
+
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+ batch = files[idx : idx + batch_size]
+ batch_time = process_batch(batch, model)
+
+ total_time += batch_time
+ processed_files += len(batch)
+
+ if (n_batch + 1) % 10 == 0:
+ eta = (
+ (time.time() - begin_time)
+ / processed_files
+ * (len(files) - processed_files)
+ )
+ logger.info(
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+ + f"ETA: {timedelta(seconds=round(eta))}s"
+ )
+
+ logger.info(
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py
index b6bc7531c41455c346109bdaaa43dafc1e3508a4..8cf77dd1bdabae29fc0d1b307e4fed74f068bf6d 100644
--- a/tools/vqgan/inference.py
+++ b/tools/vqgan/inference.py
@@ -1,122 +1,121 @@
-from pathlib import Path
-
-import click
-import hydra
-import numpy as np
-import soundfile as sf
-import torch
-import torchaudio
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from loguru import logger
-from omegaconf import OmegaConf
-
-from tools.file import AUDIO_EXTENSIONS
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-
-def load_model(config_name, checkpoint_path, device="cuda"):
- hydra.core.global_hydra.GlobalHydra.instance().clear()
- with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
- cfg = compose(config_name=config_name)
-
- model = instantiate(cfg)
- state_dict = torch.load(
- checkpoint_path,
- map_location=device,
- )
- if "state_dict" in state_dict:
- state_dict = state_dict["state_dict"]
-
- if any("generator" in k for k in state_dict):
- state_dict = {
- k.replace("generator.", ""): v
- for k, v in state_dict.items()
- if "generator." in k
- }
-
- result = model.load_state_dict(state_dict, strict=False)
- model.eval()
- model.to(device)
-
- logger.info(f"Loaded model: {result}")
- return model
-
-
-@torch.no_grad()
-@click.command()
-@click.option(
- "--input-path",
- "-i",
- default="test.wav",
- type=click.Path(exists=True, path_type=Path),
-)
-@click.option(
- "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
-)
-@click.option("--config-name", default="firefly_gan_vq")
-@click.option(
- "--checkpoint-path",
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-)
-@click.option(
- "--device",
- "-d",
- default="cuda",
-)
-def main(input_path, output_path, config_name, checkpoint_path, device):
- model = load_model(config_name, checkpoint_path, device=device)
-
- if input_path.suffix in AUDIO_EXTENSIONS:
- logger.info(f"Processing in-place reconstruction of {input_path}")
-
- # Load audio
- audio, sr = torchaudio.load(str(input_path))
- if audio.shape[0] > 1:
- audio = audio.mean(0, keepdim=True)
- audio = torchaudio.functional.resample(
- audio, sr, model.spec_transform.sample_rate
- )
-
- audios = audio[None].to(device)
- logger.info(
- f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
- )
-
- # VQ Encoder
- audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
- indices = model.encode(audios, audio_lengths)[0][0]
-
- logger.info(f"Generated indices of shape {indices.shape}")
-
- # Save indices
- np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
- elif input_path.suffix == ".npy":
- logger.info(f"Processing precomputed indices from {input_path}")
- indices = np.load(input_path)
- indices = torch.from_numpy(indices).to(device).long()
- assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
- else:
- raise ValueError(f"Unknown input type: {input_path}")
-
- # Restore
- feature_lengths = torch.tensor([indices.shape[1]], device=device)
- fake_audios, _ = model.decode(
- indices=indices[None], feature_lengths=feature_lengths
- )
- audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
-
- logger.info(
- f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
- )
-
- # Save audio
- fake_audio = fake_audios[0, 0].float().cpu().numpy()
- sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
- logger.info(f"Saved audio to {output_path}")
-
-
-if __name__ == "__main__":
- main()
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
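+    # Memory-map the checkpoint and restrict unpickling to tensors (weights_only) for safer loading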
+ state_dict = torch.load(
+ checkpoint_path, map_location=device, mmap=True, weights_only=True
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
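+    # assign=True reuses the loaded tensors directly instead of copying them into existing parameters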
+ result = model.load_state_dict(state_dict, strict=False, assign=True)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model: {result}")
+ return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+ "--input-path",
+ "-i",
+ default="test.wav",
+ type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option(
+ "--device",
+ "-d",
+ default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+ model = load_model(config_name, checkpoint_path, device=device)
+
+ if input_path.suffix in AUDIO_EXTENSIONS:
+ logger.info(f"Processing in-place reconstruction of {input_path}")
+
+ # Load audio
+ audio, sr = torchaudio.load(str(input_path))
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(
+ audio, sr, model.spec_transform.sample_rate
+ )
+
+ audios = audio[None].to(device)
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
+ indices = model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Generated indices of shape {indices.shape}")
+
+ # Save indices
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+ elif input_path.suffix == ".npy":
+ logger.info(f"Processing precomputed indices from {input_path}")
+ indices = np.load(input_path)
+ indices = torch.from_numpy(indices).to(device).long()
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+ else:
+ raise ValueError(f"Unknown input type: {input_path}")
+
+ # Restore
+ feature_lengths = torch.tensor([indices.shape[1]], device=device)
+ fake_audios, _ = model.decode(
+ indices=indices[None], feature_lengths=feature_lengths
+ )
+ audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
+
+ logger.info(
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+ )
+
+ # Save audio
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
+ sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/webui.py b/tools/webui.py
index 6271ff175de7de94d810914264c8c3677a28b12f..0fadc927fe8e6bdb12084dcb5a22f43da582bb7e 100644
--- a/tools/webui.py
+++ b/tools/webui.py
@@ -1,619 +1,553 @@
-import gc
-import html
-import io
-import os
-import queue
-import wave
-from argparse import ArgumentParser
-from functools import partial
-from pathlib import Path
-
-import gradio as gr
-import librosa
-import numpy as np
-import pyrootutils
-import torch
-from loguru import logger
-from transformers import AutoTokenizer
-
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
-
-from fish_speech.i18n import i18n
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
-from tools.api import decode_vq_tokens, encode_reference
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
-from tools.llama.generate import (
- GenerateRequest,
- GenerateResponse,
- WrappedGenerateResponse,
- launch_thread_safe_queue,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-# Make einx happy
-os.environ["EINX_FILTER_TRACEBACK"] = "false"
-
-
-HEADER_MD = f"""# Fish Speech
-
-{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
-
-{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
-
-{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
-
-{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
-"""
-
-TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
-SPACE_IMPORTED = False
-
-
-def build_html_error_message(error):
- return f"""
-
- {html.escape(str(error))}
-
- """
-
-
-@torch.inference_mode()
-def inference(
- text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- streaming=False,
-):
- if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
- return (
- None,
- None,
- i18n("Text is too long, please keep it under {} characters.").format(
- args.max_gradio_length
- ),
- )
-
- # Parse reference audio aka prompt
- prompt_tokens = encode_reference(
- decoder_model=decoder_model,
- reference_audio=reference_audio,
- enable_reference_audio=enable_reference_audio,
- )
-
- # LLAMA Inference
- request = dict(
- device=decoder_model.device,
- max_new_tokens=max_new_tokens,
- text=text,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- temperature=temperature,
- compile=args.compile,
- iterative_prompt=chunk_length > 0,
- chunk_length=chunk_length,
- max_length=2048,
- prompt_tokens=prompt_tokens if enable_reference_audio else None,
- prompt_text=reference_text if enable_reference_audio else None,
- )
-
- response_queue = queue.Queue()
- llama_queue.put(
- GenerateRequest(
- request=request,
- response_queue=response_queue,
- )
- )
-
- if streaming:
- yield wav_chunk_header(), None, None
-
- segments = []
-
- while True:
- result: WrappedGenerateResponse = response_queue.get()
- if result.status == "error":
- yield None, None, build_html_error_message(result.response)
- break
-
- result: GenerateResponse = result.response
- if result.action == "next":
- break
-
- with autocast_exclude_mps(
- device_type=decoder_model.device.type, dtype=args.precision
- ):
- fake_audios = decode_vq_tokens(
- decoder_model=decoder_model,
- codes=result.codes,
- )
-
- fake_audios = fake_audios.float().cpu().numpy()
- segments.append(fake_audios)
-
- if streaming:
- yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
-
- if len(segments) == 0:
- return (
- None,
- None,
- build_html_error_message(
- i18n("No audio generated, please check the input text.")
- ),
- )
-
- # No matter streaming or not, we need to return the final audio
- audio = np.concatenate(segments, axis=0)
- yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- gc.collect()
-
-
-def inference_with_auto_rerank(
- text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- use_auto_rerank,
- streaming=False,
-):
-
- max_attempts = 2 if use_auto_rerank else 1
- best_wer = float("inf")
- best_audio = None
- best_sample_rate = None
-
- for attempt in range(max_attempts):
- audio_generator = inference(
- text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- streaming=False,
- )
-
-        # Fetch the audio data
- for _ in audio_generator:
- pass
- _, (sample_rate, audio), message = _
-
- if audio is None:
- return None, None, message
-
- if not use_auto_rerank:
- return None, (sample_rate, audio), None
-
- asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
- wer = calculate_wer(text, asr_result["text"])
- if wer <= 0.3 and not asr_result["huge_gap"]:
- return None, (sample_rate, audio), None
-
- if wer < best_wer:
- best_wer = wer
- best_audio = audio
- best_sample_rate = sample_rate
-
- if attempt == max_attempts - 1:
- break
-
- return None, (best_sample_rate, best_audio), None
-
-
-inference_stream = partial(inference, streaming=True)
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-
-def inference_wrapper(
- text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- batch_infer_num,
- if_load_asr_model,
-):
- audios = []
- errors = []
-
- for _ in range(batch_infer_num):
- result = inference_with_auto_rerank(
- text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- if_load_asr_model,
- )
-
- _, audio_data, error_message = result
-
- audios.append(
- gr.Audio(value=audio_data if audio_data else None, visible=True),
- )
- errors.append(
- gr.HTML(value=error_message if error_message else None, visible=True),
- )
-
- for _ in range(batch_infer_num, n_audios):
- audios.append(
- gr.Audio(value=None, visible=False),
- )
- errors.append(
- gr.HTML(value=None, visible=False),
- )
-
- return None, *audios, *errors
-
-
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
- buffer = io.BytesIO()
-
- with wave.open(buffer, "wb") as wav_file:
- wav_file.setnchannels(channels)
- wav_file.setsampwidth(bit_depth // 8)
- wav_file.setframerate(sample_rate)
-
- wav_header_bytes = buffer.getvalue()
- buffer.close()
- return wav_header_bytes
-
-
-def normalize_text(user_input, use_normalization):
- if use_normalization:
- return ChnNormedText(raw_text=user_input).normalize()
- else:
- return user_input
-
-
-asr_model = None
-
-
-def change_if_load_asr_model(if_load):
- global asr_model
-
- if if_load:
- gr.Warning("Loading faster whisper model...")
- if asr_model is None:
- asr_model = load_model()
- return gr.Checkbox(label="Unload faster whisper model", value=if_load)
-
- if if_load is False:
- gr.Warning("Unloading faster whisper model...")
- del asr_model
- asr_model = None
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- gc.collect()
- return gr.Checkbox(label="Load faster whisper model", value=if_load)
-
-
-def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
- if if_load and asr_model is not None:
- if (
- if_auto_label
- and enable_ref
- and ref_audio is not None
- and ref_text.strip() == ""
- ):
- data, sample_rate = librosa.load(ref_audio)
- res = batch_asr(asr_model, [data], sample_rate)[0]
- ref_text = res["text"]
- else:
- gr.Warning("Whisper model not loaded!")
-
- return gr.Textbox(value=ref_text)
-
-
-def build_app():
- with gr.Blocks(theme=gr.themes.Base()) as app:
- gr.Markdown(HEADER_MD)
-
- # Use light theme by default
- app.load(
- None,
- None,
- js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
- % args.theme,
- )
-
- # Inference
- with gr.Row():
- with gr.Column(scale=3):
- text = gr.Textbox(
- label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
- )
- refined_text = gr.Textbox(
- label=i18n("Realtime Transform Text"),
- placeholder=i18n(
- "Normalization Result Preview (Currently Only Chinese)"
- ),
- lines=5,
- interactive=False,
- )
-
- with gr.Row():
- if_refine_text = gr.Checkbox(
- label=i18n("Text Normalization"),
- value=False,
- scale=1,
- )
-
- if_load_asr_model = gr.Checkbox(
- label=i18n("Load / Unload ASR model for auto-reranking"),
- value=False,
- scale=3,
- )
-
- with gr.Row():
- with gr.Tab(label=i18n("Advanced Config")):
- chunk_length = gr.Slider(
- label=i18n("Iterative Prompt Length, 0 means off"),
- minimum=50,
- maximum=300,
- value=200,
- step=8,
- )
-
- max_new_tokens = gr.Slider(
- label=i18n("Maximum tokens per batch, 0 means no limit"),
- minimum=0,
- maximum=2048,
- value=1024, # 0 means no limit
- step=8,
- )
-
- top_p = gr.Slider(
- label="Top-P",
- minimum=0.6,
- maximum=0.9,
- value=0.7,
- step=0.01,
- )
-
- repetition_penalty = gr.Slider(
- label=i18n("Repetition Penalty"),
- minimum=1,
- maximum=1.5,
- value=1.2,
- step=0.01,
- )
-
- temperature = gr.Slider(
- label="Temperature",
- minimum=0.6,
- maximum=0.9,
- value=0.7,
- step=0.01,
- )
-
- with gr.Tab(label=i18n("Reference Audio")):
- gr.Markdown(
- i18n(
- "5 to 10 seconds of reference audio, useful for specifying speaker."
- )
- )
-
- enable_reference_audio = gr.Checkbox(
- label=i18n("Enable Reference Audio"),
- )
- reference_audio = gr.Audio(
- label=i18n("Reference Audio"),
- type="filepath",
- )
- with gr.Row():
- if_auto_label = gr.Checkbox(
- label=i18n("Auto Labeling"),
- min_width=100,
- scale=0,
- value=False,
- )
- reference_text = gr.Textbox(
- label=i18n("Reference Text"),
- lines=1,
- placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
- value="",
- )
- with gr.Tab(label=i18n("Batch Inference")):
- batch_infer_num = gr.Slider(
- label="Batch infer nums",
- minimum=1,
- maximum=n_audios,
- step=1,
- value=1,
- )
-
- with gr.Column(scale=3):
- for _ in range(n_audios):
- with gr.Row():
- error = gr.HTML(
- label=i18n("Error Message"),
- visible=True if _ == 0 else False,
- )
- global_error_list.append(error)
- with gr.Row():
- audio = gr.Audio(
- label=i18n("Generated Audio"),
- type="numpy",
- interactive=False,
- visible=True if _ == 0 else False,
- )
- global_audio_list.append(audio)
-
- with gr.Row():
- stream_audio = gr.Audio(
- label=i18n("Streaming Audio"),
- streaming=True,
- autoplay=True,
- interactive=False,
- show_download_button=True,
- )
- with gr.Row():
- with gr.Column(scale=3):
- generate = gr.Button(
- value="\U0001F3A7 " + i18n("Generate"), variant="primary"
- )
- generate_stream = gr.Button(
- value="\U0001F3A7 " + i18n("Streaming Generate"),
- variant="primary",
- )
-
- text.input(
- fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
- )
-
- if_load_asr_model.change(
- fn=change_if_load_asr_model,
- inputs=[if_load_asr_model],
- outputs=[if_load_asr_model],
- )
-
- if_auto_label.change(
- fn=lambda: gr.Textbox(value=""),
- inputs=[],
- outputs=[reference_text],
- ).then(
- fn=change_if_auto_label,
- inputs=[
- if_load_asr_model,
- if_auto_label,
- enable_reference_audio,
- reference_audio,
- reference_text,
- ],
- outputs=[reference_text],
- )
-
- # # Submit
- generate.click(
- inference_wrapper,
- [
- refined_text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- batch_infer_num,
- if_load_asr_model,
- ],
- [stream_audio, *global_audio_list, *global_error_list],
- concurrency_limit=1,
- )
-
- generate_stream.click(
- inference_stream,
- [
- refined_text,
- enable_reference_audio,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- ],
- [stream_audio, global_audio_list[0], global_error_list[0]],
- concurrency_limit=10,
- )
- return app
-
-
-def parse_args():
- parser = ArgumentParser()
- parser.add_argument(
- "--llama-checkpoint-path",
- type=Path,
- default="checkpoints/fish-speech-1.4",
- )
- parser.add_argument(
- "--decoder-checkpoint-path",
- type=Path,
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- )
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
- parser.add_argument("--device", type=str, default="cuda")
- parser.add_argument("--half", action="store_true")
- parser.add_argument("--compile", action="store_true")
- parser.add_argument("--max-gradio-length", type=int, default=0)
- parser.add_argument("--theme", type=str, default="light")
-
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- args = parse_args()
- args.precision = torch.half if args.half else torch.bfloat16
-
- logger.info("Loading Llama model...")
- llama_queue = launch_thread_safe_queue(
- checkpoint_path=args.llama_checkpoint_path,
- device=args.device,
- precision=args.precision,
- compile=args.compile,
- )
- logger.info("Llama model loaded, loading VQ-GAN model...")
-
- decoder_model = load_decoder_model(
- config_name=args.decoder_config_name,
- checkpoint_path=args.decoder_checkpoint_path,
- device=args.device,
- )
-
- logger.info("Decoder model loaded, warming up...")
-
- # Dry run to check if the model is loaded correctly and avoid the first-time latency
- list(
- inference(
- text="Hello, world!",
- enable_reference_audio=False,
- reference_audio=None,
- reference_text="",
- max_new_tokens=0,
- chunk_length=100,
- top_p=0.7,
- repetition_penalty=1.2,
- temperature=0.7,
- )
- )
-
- logger.info("Warming up done, launching the web UI...")
-
- app = build_app()
- app.launch(show_api=True)
+import gc
+import html
+import io
+import os
+import queue
+import wave
+from argparse import ArgumentParser
+from functools import partial
+from pathlib import Path
+
+import gradio as gr
+import librosa
+import numpy as np
+import pyrootutils
+import torch
+from loguru import logger
+from transformers import AutoTokenizer
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+
+from fish_speech.i18n import i18n
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from tools.api import decode_vq_tokens, encode_reference
+from tools.file import AUDIO_EXTENSIONS, list_files
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
+
+{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+SPACE_IMPORTED = False
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(str(error))}
+
+ """
+
+
+@torch.inference_mode()
+def inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed="0",
+ streaming=False,
+):
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+ return (
+ None,
+ None,
+ i18n("Text is too long, please keep it under {} characters.").format(
+ args.max_gradio_length
+ ),
+ )
+
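+    # A non-zero seed makes generation reproducible; 0 keeps sampling random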
+ seed = int(seed)
+ if seed != 0:
+ set_seed(seed)
+ logger.warning(f"set seed: {seed}")
+
+ # Parse reference audio aka prompt
+ prompt_tokens = encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=reference_audio,
+ enable_reference_audio=enable_reference_audio,
+ )
+
+ # LLAMA Inference
+ request = dict(
+ device=decoder_model.device,
+ max_new_tokens=max_new_tokens,
+ text=text,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=args.compile,
+ iterative_prompt=chunk_length > 0,
+ chunk_length=chunk_length,
+ max_length=2048,
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
+ prompt_text=reference_text if enable_reference_audio else None,
+ )
+
+ response_queue = queue.Queue()
+ llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ if streaming:
+ yield wav_chunk_header(), None, None
+
+ segments = []
+
+ while True:
+ result: WrappedGenerateResponse = response_queue.get()
+ if result.status == "error":
+ yield None, None, build_html_error_message(result.response)
+ break
+
+ result: GenerateResponse = result.response
+ if result.action == "next":
+ break
+
+ with autocast_exclude_mps(
+ device_type=decoder_model.device.type, dtype=args.precision
+ ):
+ fake_audios = decode_vq_tokens(
+ decoder_model=decoder_model,
+ codes=result.codes,
+ )
+
+ fake_audios = fake_audios.float().cpu().numpy()
+ segments.append(fake_audios)
+
+ if streaming:
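+            # Each streamed chunk carries its own WAV header followed by 16-bit PCM samples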
+ wav_header = wav_chunk_header()
+ audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
+ yield wav_header + audio_data, None, None
+
+ if len(segments) == 0:
+ return (
+ None,
+ None,
+ build_html_error_message(
+ i18n("No audio generated, please check the input text.")
+ ),
+ )
+
+ # No matter streaming or not, we need to return the final audio
+ audio = np.concatenate(segments, axis=0)
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+inference_stream = partial(inference, streaming=True)
+
+n_audios = 4
+
+global_audio_list = []
+global_error_list = []
+
+
+def inference_wrapper(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ batch_infer_num,
+):
+ audios = []
+ errors = []
+
+ for _ in range(batch_infer_num):
+ result = inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ )
+
+ _, audio_data, error_message = next(result)
+
+ audios.append(
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
+ )
+ errors.append(
+ gr.HTML(value=error_message if error_message else None, visible=True),
+ )
+
+ for _ in range(batch_infer_num, n_audios):
+ audios.append(
+ gr.Audio(value=None, visible=False),
+ )
+ errors.append(
+ gr.HTML(value=None, visible=False),
+ )
+
+ return None, *audios, *errors
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+def normalize_text(user_input, use_normalization):
+ if use_normalization:
+ return ChnNormedText(raw_text=user_input).normalize()
+ else:
+ return user_input
+
+
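+# Collect audio files from the local "references" directory to populate the example dropdown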
+def update_examples():
+ examples_dir = Path("references")
+ examples_dir.mkdir(parents=True, exist_ok=True)
+ example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
+ return gr.Dropdown(choices=example_audios + [""])
+
+
+def build_app():
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
+ % args.theme,
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+ )
+ refined_text = gr.Textbox(
+ label=i18n("Realtime Transform Text"),
+ placeholder=i18n(
+ "Normalization Result Preview (Currently Only Chinese)"
+ ),
+ lines=5,
+ interactive=False,
+ )
+
+ with gr.Row():
+ if_refine_text = gr.Checkbox(
+ label=i18n("Text Normalization"),
+ value=False,
+ scale=1,
+ )
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab(label=i18n("Advanced Config")):
+ with gr.Row():
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=50,
+ maximum=300,
+ value=200,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n(
+ "Maximum tokens per batch, 0 means no limit"
+ ),
+ minimum=0,
+ maximum=2048,
+ value=0, # 0 means no limit
+ step=8,
+ )
+
+ with gr.Row():
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.5,
+ value=1.2,
+ step=0.01,
+ )
+
+ with gr.Row():
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+ seed = gr.Textbox(
+ label="Seed",
+ info="0 means randomized inference, otherwise deterministic",
+ placeholder="any 32-bit-integer",
+ value="0",
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ with gr.Row():
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+ with gr.Row():
+ enable_reference_audio = gr.Checkbox(
+ label=i18n("Enable Reference Audio"),
+ )
+
+ with gr.Row():
+ example_audio_dropdown = gr.Dropdown(
+ label=i18n("Select Example Audio"),
+ choices=[""],
+ value="",
+ interactive=True,
+ allow_custom_value=True,
+ )
+ with gr.Row():
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ with gr.Row():
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ lines=1,
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ value="",
+ )
+ with gr.Tab(label=i18n("Batch Inference")):
+ with gr.Row():
+ batch_infer_num = gr.Slider(
+ label="Batch infer nums",
+ minimum=1,
+ maximum=n_audios,
+ step=1,
+ value=1,
+ )
+
+ with gr.Column(scale=3):
+ for _ in range(n_audios):
+ with gr.Row():
+ error = gr.HTML(
+ label=i18n("Error Message"),
+ visible=True if _ == 0 else False,
+ )
+ global_error_list.append(error)
+ with gr.Row():
+ audio = gr.Audio(
+ label=i18n("Generated Audio"),
+ type="numpy",
+ interactive=False,
+ visible=True if _ == 0 else False,
+ )
+ global_audio_list.append(audio)
+
+ with gr.Row():
+ stream_audio = gr.Audio(
+ label=i18n("Streaming Audio"),
+ streaming=True,
+ autoplay=True,
+ interactive=False,
+ show_download_button=True,
+ )
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
+ )
+ generate_stream = gr.Button(
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
+ variant="primary",
+ )
+
+ text.input(
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
+ )
+
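+        # Load the selected example audio and its matching .lab transcript, if one exists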
+ def select_example_audio(audio_path):
+ audio_path = Path(audio_path)
+ if audio_path.is_file():
+ lab_file = Path(audio_path.with_suffix(".lab"))
+
+ if lab_file.exists():
+ lab_content = lab_file.read_text(encoding="utf-8").strip()
+ else:
+ lab_content = ""
+
+ return str(audio_path), lab_content, True
+ return None, "", False
+
+ # Connect the dropdown to update reference audio and text
+ example_audio_dropdown.change(
+ fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
+ ).then(
+ fn=select_example_audio,
+ inputs=[example_audio_dropdown],
+ outputs=[reference_audio, reference_text, enable_reference_audio],
+ )
+
+ # # Submit
+ generate.click(
+ inference_wrapper,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ batch_infer_num,
+ ],
+ [stream_audio, *global_audio_list, *global_error_list],
+ concurrency_limit=1,
+ )
+
+ generate_stream.click(
+ inference_stream,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ ],
+ [stream_audio, global_audio_list[0], global_error_list[0]],
+ concurrency_limit=1,
+ )
+ return app
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.4",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+ parser.add_argument("--theme", type=str, default="light")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ # Check if CUDA is available
+ if not torch.cuda.is_available():
+ logger.info("CUDA is not available, running on CPU.")
+ args.device = "cpu"
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference(
+ text="Hello, world!",
+ enable_reference_audio=False,
+ reference_audio=None,
+ reference_text="",
+ max_new_tokens=0,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ app = build_app()
+ app.launch(show_api=True)
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
index 42e7de8a185880d3f2afd368d6df3429488465a4..8ac720d2dad94fca5087852928f7e6ce57d1b92d 100644
--- a/tools/whisper_asr.py
+++ b/tools/whisper_asr.py
@@ -1,176 +1,176 @@
-"""
-Used to transcribe all audio files in one folder into another folder.
-e.g.
-Directory structure:
---pre_data_root
-----SP_1
-------01.wav
-------02.wav
-------......
-----SP_2
-------01.wav
-------02.wav
-------......
-Use
-python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
-to transcribe the first speaker.
-
-Use
-python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
-to transcribe the second speaker.
-
-Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
-"""
-
-import re
-from pathlib import Path
-
-import click
-import soundfile as sf
-from faster_whisper import WhisperModel
-from loguru import logger
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files
-
-
-@click.command()
-@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
-@click.option(
- "--compute-type",
- default="float16",
- help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
-)
-@click.option("--audio-dir", required=True, help="Directory containing audio files")
-@click.option(
- "--save-dir", required=True, help="Directory to save processed audio files"
-)
-@click.option(
- "--sample-rate",
- default=44100,
- type=int,
- help="Output sample rate, default to input sample rate",
-)
-@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
-@click.option("--language", default="auto", help="Language of the transcription")
-@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
-def main(
- model_size,
- compute_type,
- audio_dir,
- save_dir,
- sample_rate,
- device,
- language,
- initial_prompt,
-):
- logger.info("Loading / Downloading Faster Whisper model...")
-
- model = WhisperModel(
- model_size,
- device=device,
- compute_type=compute_type,
- download_root="faster_whisper",
- )
-
- logger.info("Model loaded.")
-
- save_path = Path(save_dir)
- save_path.mkdir(parents=True, exist_ok=True)
-
- audio_files = list_files(
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
- )
-
- for file_path in tqdm(audio_files, desc="Processing audio file"):
- file_stem = file_path.stem
- file_suffix = file_path.suffix
-
- rel_path = Path(file_path).relative_to(audio_dir)
- (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-
- audio = AudioSegment.from_file(file_path)
-
- segments, info = model.transcribe(
- file_path,
- beam_size=5,
- language=None if language == "auto" else language,
- initial_prompt=initial_prompt,
- )
-
- print(
- "Detected language '%s' with probability %f"
- % (info.language, info.language_probability)
- )
- print("Total len(ms): ", len(audio))
-
- whole_text = None
- for segment in segments:
- id, start, end, text = (
- segment.id,
- segment.start,
- segment.end,
- segment.text,
- )
- print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
- if not whole_text:
- whole_text = text
- else:
- whole_text += ", " + text
-
- whole_text += "."
-
- audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
- audio.export(audio_save_path, format=file_suffix[1:])
- print(f"Exported {audio_save_path}")
-
- transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
- with open(
- transcript_save_path,
- "w",
- encoding="utf-8",
- ) as f:
- f.write(whole_text)
-
-
-if __name__ == "__main__":
- main()
- exit(0)
-
- audio = AudioSegment.from_wav(
- r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
- )
-
- model_size = "large-v3"
-
- model = WhisperModel(
- model_size,
- device="cuda",
- compute_type="float16",
- download_root="faster_whisper",
- )
-
- segments, info = model.transcribe(
- r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
- beam_size=5,
- )
-
- print(
- "Detected language '%s' with probability %f"
- % (info.language, info.language_probability)
- )
- print("Total len(ms): ", len(audio))
-
- for i, segment in enumerate(segments):
- print(
- "Segment %03d [%.2fs -> %.2fs] %s"
- % (i, segment.start, segment.end, segment.text)
- )
- start_ms = int(segment.start * 1000)
- end_ms = int(segment.end * 1000)
- segment_audio = audio[start_ms:end_ms]
- segment_audio.export(f"segment_{i:03d}.wav", format="wav")
- print(f"Exported segment_{i:03d}.wav")
-
- print("All segments have been exported.")
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+from pathlib import Path
+
+import click
+from faster_whisper import WhisperModel
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
+@click.option(
+ "--compute-type",
+ default="float16",
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
+)
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option(
+ "--sample-rate",
+ default=44100,
+ type=int,
+ help="Output sample rate, default to input sample rate",
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+ model_size,
+ compute_type,
+ audio_dir,
+ save_dir,
+ sample_rate,
+ device,
+ language,
+ initial_prompt,
+):
+ logger.info("Loading / Downloading Faster Whisper model...")
+
+ model = WhisperModel(
+ model_size,
+ device=device,
+ compute_type=compute_type,
+ download_root="faster_whisper",
+ )
+
+ logger.info("Model loaded.")
+
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
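+        # faster-whisper returns a lazy generator; decoding runs as the segments are iterated below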
+ segments, info = model.transcribe(
+ file_path,
+ beam_size=5,
+ language=None if language == "auto" else language,
+ initial_prompt=initial_prompt,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
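+        # Concatenate the segment texts into a single comma-separated transcript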
+        whole_text = ""  # start as an empty string so the final "." append does not fail when no segments are returned
+ for segment in segments:
+ id, start, end, text = (
+ segment.id,
+ segment.start,
+ segment.end,
+ segment.text,
+ )
+ print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
+ if not whole_text:
+ whole_text = text
+ else:
+ whole_text += ", " + text
+
+ whole_text += "."
+
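+        # Write the audio copy and its .lab transcript side by side, mirroring the input directory layout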
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
+ audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}")
+
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(whole_text)
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+
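+    # NOTE: the code below is unreachable (main() exits above); it appears to be leftover
+    # scratch code for slicing a single file into per-segment WAV exports.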
+ audio = AudioSegment.from_wav(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
+ )
+
+ model_size = "large-v3"
+
+ model = WhisperModel(
+ model_size,
+ device="cuda",
+ compute_type="float16",
+ download_root="faster_whisper",
+ )
+
+ segments, info = model.transcribe(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
+ beam_size=5,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ for i, segment in enumerate(segments):
+ print(
+ "Segment %03d [%.2fs -> %.2fs] %s"
+ % (i, segment.start, segment.end, segment.text)
+ )
+ start_ms = int(segment.start * 1000)
+ end_ms = int(segment.end * 1000)
+ segment_audio = audio[start_ms:end_ms]
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
+ print(f"Exported segment_{i:03d}.wav")
+
+ print("All segments have been exported.")