import os import sys import yaml import torch import numpy as np import typing as tp from pathlib import Path from hashlib import sha256 now_dir = os.getcwd() sys.path.append(now_dir) from main.configs.config import Config from main.library.uvr5_separator import spec_utils from main.library.uvr5_separator.demucs.hdemucs import HDemucs from main.library.uvr5_separator.demucs.states import load_model from main.library.uvr5_separator.demucs.apply import BagOfModels, Model from main.library.uvr5_separator.common_separator import CommonSeparator from main.library.uvr5_separator.demucs.apply import apply_model, demucs_segments translations = Config().translations DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"] DEMUCS_2_SOURCE_MAPPER = { CommonSeparator.INST_STEM: 0, CommonSeparator.VOCAL_STEM: 1 } DEMUCS_4_SOURCE_MAPPER = { CommonSeparator.BASS_STEM: 0, CommonSeparator.DRUM_STEM: 1, CommonSeparator.OTHER_STEM: 2, CommonSeparator.VOCAL_STEM: 3 } DEMUCS_6_SOURCE_MAPPER = { CommonSeparator.BASS_STEM: 0, CommonSeparator.DRUM_STEM: 1, CommonSeparator.OTHER_STEM: 2, CommonSeparator.VOCAL_STEM: 3, CommonSeparator.GUITAR_STEM: 4, CommonSeparator.PIANO_STEM: 5, } REMOTE_ROOT = Path(__file__).parent / "remote" PRETRAINED_MODELS = { "demucs": "e07c671f", "demucs48_hq": "28a1282c", "demucs_extra": "3646af93", "demucs_quantized": "07afea75", "tasnet": "beb46fac", "tasnet_extra": "df3777b2", "demucs_unittest": "09ebc15f", } sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator")) AnyModel = tp.Union[Model, BagOfModels] class DemucsSeparator(CommonSeparator): def __init__(self, common_config, arch_config): super().__init__(config=common_config) self.segment_size = arch_config.get("segment_size", "Default") self.shifts = arch_config.get("shifts", 2) self.overlap = arch_config.get("overlap", 0.25) self.segments_enabled = arch_config.get("segments_enabled", True) self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled)) self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap)) self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER self.audio_file_path = None self.audio_file_base = None self.demucs_model_instance = None self.logger.info(translations["start_demucs"]) def separate(self, audio_file_path): self.logger.debug(translations["start_separator"]) source = None stem_source = None inst_source = {} self.audio_file_path = audio_file_path self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0] self.logger.debug(translations["prepare_mix"]) mix = self.prepare_mix(self.audio_file_path) self.logger.debug(translations["demix"].format(shape=mix.shape)) self.logger.debug(translations["cancel_mix"]) self.demucs_model_instance = HDemucs(sources=DEMUCS_4_SOURCE) self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path))) self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance) self.demucs_model_instance.to(self.torch_device) self.demucs_model_instance.eval() self.logger.debug(translations["model_review"]) source = self.demix_demucs(mix) del self.demucs_model_instance self.clear_gpu_cache() self.logger.debug(translations["del_gpu_cache_after_demix"]) output_files = [] self.logger.debug(translations["process_output_file"]) if isinstance(inst_source, np.ndarray): self.logger.debug(translations["process_ver"]) source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]]) inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]] = source_reshape source = inst_source if isinstance(source, np.ndarray): source_length = len(source) self.logger.debug(translations["source_length"].format(source_length=source_length)) match source_length: case 2: self.logger.debug(translations["set_map"].format(part="2")) self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER case 6: self.logger.debug(translations["set_map"].format(part="6")) self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER case _: self.logger.debug(translations["set_map"].format(part="2")) self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER self.logger.debug(translations["process_all_part"]) for stem_name, stem_value in self.demucs_source_map.items(): if self.output_single_stem is not None: if stem_name.lower() != self.output_single_stem.lower(): self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem)) continue stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}") stem_source = source[stem_value].T self.final_process(stem_path, stem_source, stem_name) output_files.append(stem_path) return output_files def demix_demucs(self, mix): self.logger.debug(translations["starting_demix_demucs"]) processed = {} mix = torch.tensor(mix, dtype=torch.float32) ref = mix.mean(0) mix = (mix - ref.mean()) / ref.std() mix_infer = mix with torch.no_grad(): self.logger.debug(translations["model_infer"]) sources = apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0] sources = (sources * ref.std() + ref.mean()).cpu().numpy() sources[[0, 1]] = sources[[1, 0]] processed[mix] = sources[:, :, 0:None].copy() sources = list(processed.values()) sources = [s[:, :, 0:None] for s in sources] sources = np.concatenate(sources, axis=-1) return sources class ModelOnlyRepo: def has_model(self, sig: str) -> bool: raise NotImplementedError() def get_model(self, sig: str) -> Model: raise NotImplementedError() class RemoteRepo(ModelOnlyRepo): def __init__(self, models: tp.Dict[str, str]): self._models = models def has_model(self, sig: str) -> bool: return sig in self._models def get_model(self, sig: str) -> Model: try: url = self._models[sig] except KeyError: raise RuntimeError(translations["not_found_model_signature"].format(sig=sig)) pkg = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True) return load_model(pkg) class LocalRepo(ModelOnlyRepo): def __init__(self, root: Path): self.root = root self.scan() def scan(self): self._models = {} self._checksums = {} for file in self.root.iterdir(): if file.suffix == ".th": if "-" in file.stem: xp_sig, checksum = file.stem.split("-") self._checksums[xp_sig] = checksum else: xp_sig = file.stem if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig)) self._models[xp_sig] = file def has_model(self, sig: str) -> bool: return sig in self._models def get_model(self, sig: str) -> Model: try: file = self._models[sig] except KeyError: raise RuntimeError(translations["not_found_model_signature"].format(sig=sig)) if sig in self._checksums: check_checksum(file, self._checksums[sig]) return load_model(file) class BagOnlyRepo: def __init__(self, root: Path, model_repo: ModelOnlyRepo): self.root = root self.model_repo = model_repo self.scan() def scan(self): self._bags = {} for file in self.root.iterdir(): if file.suffix == ".yaml": self._bags[file.stem] = file def has_model(self, name: str) -> bool: return name in self._bags def get_model(self, name: str) -> BagOfModels: try: yaml_file = self._bags[name] except KeyError: raise RuntimeError(translations["name_not_pretrained"].format(name=name)) bag = yaml.safe_load(open(yaml_file)) signatures = bag["models"] models = [self.model_repo.get_model(sig) for sig in signatures] weights = bag.get("weights") segment = bag.get("segment") return BagOfModels(models, weights, segment) class AnyModelRepo: def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): self.model_repo = model_repo self.bag_repo = bag_repo def has_model(self, name_or_sig: str) -> bool: return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) def get_model(self, name_or_sig: str) -> AnyModel: if self.model_repo.has_model(name_or_sig): return self.model_repo.get_model(name_or_sig) else: return self.bag_repo.get_model(name_or_sig) def check_checksum(path: Path, checksum: str): sha = sha256() with open(path, "rb") as file: while 1: buf = file.read(2**20) if not buf: break sha.update(buf) actual_checksum = sha.hexdigest()[: len(checksum)] if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum)) def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: root: str = "" models: tp.Dict[str, str] = {} for line in remote_file_list.read_text().split("\n"): line = line.strip() if line.startswith("#"): continue elif line.startswith("root:"): root = line.split(":", 1)[1].strip() else: sig = line.split("-", 1)[0] assert sig not in models models[sig] = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" + root + line return models def get_demucs_model(name: str, repo: tp.Optional[Path] = None): if name == "demucs_unittest": return HDemucs(channels=4, sources=DEMUCS_4_SOURCE) model_repo: ModelOnlyRepo if repo is None: models = _parse_remote_files(REMOTE_ROOT / "files.txt") model_repo = RemoteRepo(models) bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) else: if not repo.is_dir(): print(translations["repo_must_be_folder"].format(repo=repo)) model_repo = LocalRepo(repo) bag_repo = BagOnlyRepo(repo, model_repo) any_repo = AnyModelRepo(model_repo, bag_repo) model = any_repo.get_model(name) model.eval() return model