import random
import typing as tp
from concurrent.futures import ThreadPoolExecutor

import torch
import tqdm
from torch import nn
from torch.nn import functional as F

from .demucs import Demucs
from .hdemucs import HDemucs
from .htdemucs import HTDemucs
from .utils import center_trim

# Any single separation model this module knows how to drive.
Model = tp.Union[Demucs, HDemucs, HTDemucs]


class DummyPoolExecutor:
    """Synchronous drop-in for `ThreadPoolExecutor`.

    Used when multithreading is pointless or harmful (GPU inference, or
    `num_workers == 0`).  `submit` does not run anything eagerly: the call is
    recorded and executed only when `.result()` is requested, which is the
    only part of the Future interface `apply_model` relies on.
    """

    class DummyResult:
        # Minimal Future look-alike: remembers the call, runs it on demand.
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        # `workers` is accepted purely for signature compatibility.
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return


class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """Represents a bag of models with specific per-source weights.

        Call `apply_model` rather than `forward` — `forward` is deliberately
        not implemented, as blending must happen chunk by chunk.

        Args:
            models: list of Demucs/HDemucs/HTDemucs models. They must all
                agree on `sources`, `samplerate` and `audio_channels`.
            weights: one list of floats per model, one float per source.
                Defaults to all ones.
            segment: if given, overrides the `segment` attribute of each
                model **in place** — be careful if you reuse the models.
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        # All models must share the same output contract.
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1.0 for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    def forward(self, x):
        # A bag has no single forward pass; use `apply_model` instead.
        raise NotImplementedError("`apply_model`")


class TensorChunk:
    """Lazy, zero-copy view over the last dimension of a tensor.

    Only materializes data (with symmetric zero padding) when `padded` is
    called, which keeps the chunked / shifted inference below cheap.
    """

    def __init__(self, tensor, offset: int = 0, length: tp.Optional[int] = None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        # Clamp the requested length to what is actually available.
        length = total_length - offset if length is None else min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            # Collapse nested chunks so we always reference the base tensor.
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        # Same as the base tensor, except for the (possibly shorter) last dim.
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length: int) -> torch.Tensor:
        """Materialize the chunk, centered and zero-padded to `target_length`.

        Where the enlarged window still falls inside the base tensor, real
        samples are used instead of zeros.
        """
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        # Ideal (possibly out-of-bounds) window, centered on the chunk.
        start = self.offset - delta // 2
        end = start + target_length

        # Clip to the base tensor; the clipped-off part becomes zero padding.
        correct_start = max(0, start)
        correct_end = min(total_length, end)
        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    """Wrap a `torch.Tensor` into a `TensorChunk`; pass chunks through as-is."""
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    assert isinstance(tensor_or_chunk, torch.Tensor)
    return TensorChunk(tensor_or_chunk)


def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0,
                static_shifts=1, set_progress_bar=None, device=None, progress=False,
                num_workers=0, pool=None):
    """Apply `model` to a mixture of shape `(batch, channels, length)`.

    Dispatches recursively: a `BagOfModels` is unrolled into its sub-models,
    `shifts` adds random-shift equivariant stabilization, and `split` runs the
    model on overlapping fixed-size segments blended with a triangular window.

    Args:
        model: a single `Model` or a `BagOfModels`.
        mix: tensor (or `TensorChunk`) holding the mixture.
        shifts: if > 0, average predictions over that many random shifts
            of up to 0.5 s (makes inference `shifts` times slower).
        split: if True, process the input in chunks of `model.segment`
            seconds, which is required for model parallelism / memory.
        overlap: fraction of overlap between consecutive chunks.
        transition_power: exponent (>= 1) applied to the triangular blend
            weight; higher values favor the center of each chunk.
        static_shifts: total shift count used only for progress accounting.
        set_progress_bar: optional callback `fn(base, fraction)` for UI
            progress reporting (uses the module-level globals below).
        device: device to run inference on; the data itself may stay on a
            different device. Defaults to `mix.device`.
        progress: if True, show a tqdm progress bar over chunks.
        num_workers: worker threads for CPU parallel chunk evaluation.
        pool: explicit executor; built from `num_workers`/`device` if None.

    Returns:
        Tensor of shape `(batch, sources, channels, length)`.
    """
    # NOTE(review): these globals carry progress state between the bag-level
    # call and the recursive per-chunk calls. `bag_num` is only assigned in
    # the BagOfModels branch, so calling apply_model directly on a single
    # model with `set_progress_bar` set relies on a previous bag call having
    # initialized it — confirm before using that combination.
    global fut_length
    global bag_num
    global prog_bar

    device = mix.device if device is None else torch.device(device)
    if pool is None:
        # Threads only help for CPU inference; otherwise run inline.
        if num_workers > 0 and device.type == "cpu":
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()

    kwargs = {
        "shifts": shifts,
        "split": split,
        "overlap": overlap,
        "transition_power": transition_power,
        "progress": progress,
        "device": device,
        "pool": pool,
        "set_progress_bar": set_progress_bar,
        "static_shifts": static_shifts,
    }

    if isinstance(model, BagOfModels):
        # Weighted blend of every sub-model's estimate, source by source.
        estimates = 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)
        fut_length = 0
        prog_bar = 0
        for sub_model, model_weights in zip(model.models, model.weights):
            # Move the sub-model to the inference device only for its turn,
            # restoring it afterwards to keep peak memory down.
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)

            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(model_weights):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        # Normalize each source by the total weight it received.
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    model.to(device)
    model.eval()
    assert transition_power >= 1
    batch, channels, length = mix.shape

    if shifts:
        # Randomized equivariant stabilization: average over random shifts.
        kwargs["shifts"] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            # Drop the leading padding so all shifts align to the original.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out

    if split:
        # Overlap-add over fixed-size segments with a triangular window.
        kwargs["split"] = False
        out = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = torch.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        # Triangle peaking at the segment center; `transition_power >= 1`
        # sharpens it so chunk centers dominate the blend.
        weight = torch.cat([torch.arange(1, segment // 2 + 1, device=device),
                            torch.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        weight = (weight / weight.max()) ** transition_power

        futures = []
        for chunk_offset in offsets:
            chunk = TensorChunk(mix, chunk_offset, segment)
            futures.append((pool.submit(apply_model, model, chunk, **kwargs), chunk_offset))
        if progress:
            futures = tqdm.tqdm(futures)

        for future, chunk_offset in futures:
            if set_progress_bar:
                fut_length = len(futures) * bag_num * static_shifts
                prog_bar += 1
                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            # The final chunk can be shorter than `segment`; slice the window.
            out[..., chunk_offset:chunk_offset + segment] += \
                (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[chunk_offset:chunk_offset + segment] += \
                weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out

    # Base case: one padded forward pass over the whole input.
    valid_length = model.valid_length(length) if hasattr(model, "valid_length") else length
    mix = tensor_chunk(mix)
    padded_mix = mix.padded(valid_length).to(device)
    with torch.no_grad():
        out = model(padded_mix)
    return center_trim(out, length)


def _set_segment(demucs_model, segment):
    """Apply `segment` to `demucs_model` (every sub-model if it is a bag)."""
    if segment is None:
        return
    if isinstance(demucs_model, BagOfModels):
        for sub in demucs_model.models:
            sub.segment = segment
    else:
        # BUGFIX: the original referenced an unbound name `sub` here,
        # raising NameError for single (non-bag) models.
        demucs_model.segment = segment


def demucs_segments(demucs_segment, demucs_model):
    """Override the model's segment length from a UI string setting.

    `"Default"` or any non-integer value leaves the model untouched;
    an integer string sets `segment` on the model (or each sub-model).
    Returns the (possibly mutated) model.
    """
    if demucs_segment == "Default":
        segment = None
    else:
        try:
            segment = int(demucs_segment)
        except (TypeError, ValueError):
            # Unparsable setting: fall back to the model's own default.
            segment = None
    _set_segment(demucs_model, segment)
    return demucs_model