|
import tqdm |
|
import torch |
|
import random |
|
|
|
import typing as tp |
|
|
|
from torch import nn |
|
from torch.nn import functional as F |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
from .demucs import Demucs |
|
from .hdemucs import HDemucs |
|
from .htdemucs import HTDemucs |
|
from .utils import center_trim |
|
|
|
|
|
# Union of all supported separator architectures; used for annotations below.
Model = tp.Union[Demucs, HDemucs, HTDemucs]
|
|
|
class DummyPoolExecutor:
    """Single-threaded, drop-in stand-in for ``concurrent.futures.ThreadPoolExecutor``.

    ``submit`` does not run anything; it wraps the call in a ``DummyResult``
    whose ``result()`` executes the callable lazily on the calling thread.
    Useful when a worker pool is unwanted (e.g. GPU inference).
    """

    class DummyResult:
        """Future-like object that invokes the stored callable on demand."""

        def __init__(self, func, *args, **kwargs):
            # Keep the callable and its arguments until result() is requested.
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            # Mirrors Future.result(): run the work now and return its value.
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        # `workers` exists only for signature compatibility; it is ignored.
        pass

    def submit(self, func, *args, **kwargs):
        """Return a lazy DummyResult instead of scheduling on a pool."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        # Support `with DummyPoolExecutor() as pool:` like a real executor.
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing to shut down; propagate any exception (returns None).
        return
|
|
|
class BagOfModels(nn.Module):
    """Ensemble ("bag") of separation models sharing one output specification.

    Every bagged model must agree on ``sources``, ``samplerate`` and
    ``audio_channels``; those shared values are mirrored onto the bag itself.
    Optional per-model/per-source ``weights`` control how sub-model outputs
    are mixed (all 1.0 by default). Inference must go through ``apply_model``.
    """

    def __init__(self, models: tp.List[Model], weights: tp.Optional[tp.List[tp.List[float]]] = None, segment: tp.Optional[float] = None):
        super().__init__()
        assert len(models) > 0
        reference = models[0]

        # Validate that all models are interchangeable, and optionally force
        # a common segment length onto each of them.
        for candidate in models:
            assert candidate.sources == reference.sources
            assert candidate.samplerate == reference.samplerate
            assert candidate.audio_channels == reference.audio_channels
            if segment is not None:
                candidate.segment = segment

        self.audio_channels = reference.audio_channels
        self.samplerate = reference.samplerate
        self.sources = reference.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            # Uniform weighting: every model contributes equally per source.
            weights = [[1.0] * len(reference.sources) for _ in models]
        else:
            assert len(weights) == len(models)
            for per_model in weights:
                assert len(per_model) == len(reference.sources)
        self.weights = weights

    def forward(self, x):
        # A bag has no single forward pass; callers must use `apply_model`.
        raise NotImplementedError("`apply_model`")
|
|
|
|
|
class TensorChunk:
    """Zero-copy window over the last dimension of a tensor.

    Stores an ``(offset, length)`` view without slicing the data; wrapping a
    TensorChunk in another TensorChunk collapses into a single offset into
    the underlying real tensor.
    """

    def __init__(self, tensor, offset=0, length=None):
        total = tensor.shape[-1]
        assert offset >= 0
        assert offset < total

        # Clamp the requested length to what remains after the offset.
        if length is None:
            length = total - offset
        else:
            length = min(total - offset, length)

        if isinstance(tensor, TensorChunk):
            # Flatten nesting: keep a reference to the real tensor and
            # accumulate offsets so `self.tensor` is never a chunk.
            self.tensor = tensor.tensor
            self.offset = tensor.offset + offset
        else:
            self.tensor = tensor
            self.offset = offset

        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the view: the underlying shape with the last dim replaced."""
        dims = list(self.tensor.shape)
        dims[-1] = self.length
        return dims

    def padded(self, target_length):
        """Materialize the window grown symmetrically to ``target_length``.

        Extra samples are taken from the underlying tensor where available,
        and zero-padded where the grown window runs past its ends.
        """
        delta = target_length - self.length
        total = self.tensor.shape[-1]
        assert delta >= 0

        # Ideal (possibly out-of-bounds) window, centered on the chunk.
        start = self.offset - delta // 2
        end = start + target_length

        # Clip to the real tensor, padding with zeros for the clipped parts.
        clipped_start = max(0, start)
        clipped_end = min(total, end)
        out = F.pad(
            self.tensor[..., clipped_start:clipped_end],
            (clipped_start - start, end - clipped_end),
        )

        assert out.shape[-1] == target_length
        return out
|
|
|
|
|
def tensor_chunk(tensor_or_chunk):
    """Coerce the argument into a ``TensorChunk`` view.

    An existing chunk is returned unchanged; a plain tensor is wrapped in a
    full-length chunk. Anything else fails the isinstance assertion.
    """
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    assert isinstance(tensor_or_chunk, torch.Tensor)
    return TensorChunk(tensor_or_chunk)
|
|
|
|
|
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
    """Apply `model` to the full mixture `mix`, returning per-source estimates.

    The work is organized as nested strategies, each handled by a recursive
    call with the corresponding option disabled:
      1. BagOfModels: run each sub-model and average outputs with the bag's
         per-source weights.
      2. shifts > 0: average predictions over `shifts` random time offsets
         of up to 0.5 s (shift-equivariance trick).
      3. split: process the track in overlapping chunks of
         `model.segment` seconds, blended with a triangular weight window.
      4. Base case: a single padded forward pass through the model.

    Args:
        model: a Demucs-like model or a BagOfModels.
        mix: Tensor or TensorChunk of shape [batch, channels, length].
        shifts: number of random shifts to average over (0 disables).
        split: whether to chunk the input along time.
        overlap: fraction of overlap between consecutive chunks.
        transition_power: exponent (>= 1) sharpening the blend window.
        static_shifts: total shift count, used only to scale progress.
        set_progress_bar: optional callback(base, fraction) for UI progress.
        device: computation device; defaults to `mix.device`.
        progress: show a tqdm bar over chunk futures.
        num_workers: CPU thread-pool size (only used when device is CPU).
        pool: optional executor; built from `num_workers` when None.

    Returns:
        Tensor of shape [batch, len(model.sources), channels, length].
    """
    # Module-level state shared across the recursive calls so the progress
    # callback can report a single global fraction.
    global fut_length
    global bag_num
    global prog_bar

    device = mix.device if device is None else torch.device(device)

    # Threads only help on CPU (GPU work would serialize on the device);
    # otherwise fall back to the inline dummy executor.
    if pool is None: pool = ThreadPoolExecutor(num_workers) if num_workers > 0 and device.type == "cpu" else DummyPoolExecutor()

    # Options forwarded to recursive calls; individual keys are overridden
    # below as each strategy level is consumed.
    kwargs = {
        "shifts": shifts,
        "split": split,
        "overlap": overlap,
        "transition_power": transition_power,
        "progress": progress,
        "device": device,
        "pool": pool,
        "set_progress_bar": set_progress_bar,
        "static_shifts": static_shifts,
    }

    if isinstance(model, BagOfModels):
        # Weighted average of every sub-model's output.
        estimates = 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)
        fut_length = 0
        prog_bar = 0
        current_model = 0

        for sub_model, weight in zip(model.models, model.weights):
            # Move the sub-model to the compute device for the duration of
            # its pass, then restore it to where it lived before.
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            # NOTE(review): `fut_length += fut_length` doubles zero, so this
            # is a no-op; fut_length is recomputed in the split loop below.
            fut_length += fut_length
            # NOTE(review): current_model is incremented but never read.
            current_model += 1

            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)

            # Scale each source channel by its weight and track the total
            # weight per source for the final normalization.
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight

            estimates += out
            del out

        # Normalize so the result is a weighted mean per source.
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]

        return estimates

    model.to(device)
    model.eval()

    assert transition_power >= 1

    batch, channels, length = mix.shape

    if shifts:
        # Randomized-shift averaging: pad by max_shift on both sides, run
        # the model on randomly offset windows, and realign before summing.
        kwargs["shifts"] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0

        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            # Drop the leading samples so every shifted output realigns with
            # the original timeline before accumulation.
            out += shifted_out[..., max_shift - offset :]

        out /= shifts

        return out
    elif split:
        # Overlap-add over fixed-size chunks, blended by a triangular window
        # raised to `transition_power` and renormalized at the end.
        kwargs["split"] = False
        out = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = torch.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)

        # Triangular window: ramps up to segment//2, then back down.
        weight = torch.cat([torch.arange(1, segment // 2 + 1, device=device), torch.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment

        # transition_power > 1 narrows the effective cross-fade region.
        weight = (weight / weight.max()) ** transition_power
        futures = []

        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            # NOTE(review): this increment is dead — `offset` is rebound by
            # the for-loop on the next iteration.
            offset += segment

        if progress: futures = tqdm.tqdm(futures)

        for future, offset in futures:
            if set_progress_bar:
                # NOTE(review): relies on `bag_num` having been set by an
                # enclosing BagOfModels call; a direct call on a bare model
                # with set_progress_bar would hit an unset global.
                fut_length = len(futures) * bag_num * static_shifts
                prog_bar += 1
                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))

            chunk_out = future.result()
            # The last chunk may be shorter than `segment`; slice the window
            # to match before overlap-adding.
            chunk_length = chunk_out.shape[-1]

            out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)

        assert sum_weight.min() > 0

        # Renormalize the overlap-add so overlapping windows average out.
        out /= sum_weight

        return out
    else:
        # Base case: single forward pass, padded to a model-friendly length
        # and trimmed back to the original length afterwards.
        valid_length = model.valid_length(length) if hasattr(model, "valid_length") else length

        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)

        with torch.no_grad():
            out = model(padded_mix)

        return center_trim(out, length)
|
|
|
|
|
def demucs_segments(demucs_segment, demucs_model):
    """Apply a user-selected segment length to a Demucs model or bag of models.

    Args:
        demucs_segment: the literal string "Default" to keep the model's
            built-in segment length, or any value convertible to `int`
            giving the desired segment length in seconds.
        demucs_model: a single model or a BagOfModels; mutated in place.

    Returns:
        The same `demucs_model`, with `.segment` updated on every sub-model
        (or on the model itself) when a valid non-default value was given.
    """
    if demucs_segment == "Default":
        segment = None
    else:
        try:
            segment = int(demucs_segment)
        except (TypeError, ValueError):
            # Unparseable user input: silently keep the model default, as
            # the original best-effort behavior intended. (The original used
            # a bare `except:`; narrowed to the conversion errors here.)
            segment = None

    if segment is not None:
        if isinstance(demucs_model, BagOfModels):
            for sub in demucs_model.models:
                sub.segment = segment
        else:
            # BUG FIX: the original referenced an undefined name `sub` here,
            # so for a single (non-bag) model the NameError was swallowed by
            # the bare except and the segment was never applied. Set it on
            # the model itself, which is clearly the intent.
            demucs_model.segment = segment

    return demucs_model