import os
import sys
import math

import torch
import librosa
import torchaudio
import numpy as np

sys.path.append(os.getcwd())

from main.library.predictors.FCN.model import MODEL
from main.library.predictors.FCN.convert import frequency_to_bins, seconds_to_samples, bins_to_frequency

# Model constants: 5 cents per pitch bin, 1440 pitch bins, 16 kHz model sample rate, 1024-sample analysis window.
CENTS_PER_BIN, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 5, 1440, 16000, 1024

class FCN:
    """Fully-convolutional pitch predictor with PyTorch and ONNX Runtime backends."""

    def __init__(self, model_path, hop_length=160, batch_size=None, f0_min=50, f0_max=1100, device=None, sample_rate=16000, providers=None, onnx=False):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.hopsize = hop_length / SAMPLE_RATE  # hop length in seconds
        self.batch_size = batch_size
        self.sample_rate = sample_rate
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max

        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            model = MODEL()
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            model.load_state_dict(ckpt['model'])
            model.eval()
            self.model = model.to(self.device)  # use the resolved device, not the possibly-None argument

    def entropy(self, logits):
        # Periodicity as 1 minus the normalized entropy of the posterior over pitch
        # bins: confident (peaky) frames score near 1, noisy frames near 0.
        distribution = torch.nn.functional.softmax(logits, dim=1)
        return (1 + 1 / math.log(PITCH_BINS) * (distribution * torch.log(distribution + 1e-7)).sum(dim=1))

    def expected_frames(self, samples, center):
        hopsize_resampled = seconds_to_samples(self.hopsize, self.sample_rate)

        if center == 'half-window':
            # Frames must fit entirely inside the signal; trim one window minus one hop.
            window_size_resampled = WINDOW_SIZE / SAMPLE_RATE * self.sample_rate
            samples = samples - (window_size_resampled - hopsize_resampled)
        elif center == 'half-hop':
            pass  # one frame per hop; the padding is applied in preprocess()
        elif center == 'zero':
            samples = samples + hopsize_resampled
        else: raise ValueError(f"Unknown centering mode: {center}")

        return max(1, int(samples / hopsize_resampled))
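
    # A quick sanity check of the frame arithmetic, assuming one second of 16 kHz
    # audio and the default 160-sample hop (worked out by hand, not taken from the
    # original source):
    #   'half-window': int((16000 - (1024 - 160)) / 160) = 94 frames
    #   'half-hop':    int(16000 / 160)                  = 100 frames
    #   'zero':        int((16000 + 160) / 160)          = 101 frames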

    def resample(self, audio, target_rate=SAMPLE_RATE):
        if self.sample_rate == target_rate: return audio

        resampler = torchaudio.transforms.Resample(self.sample_rate, target_rate)
        resampler = resampler.to(audio.device)

        return resampler(audio)

    def preprocess(self, audio, center='half-window'):
        total_frames = self.expected_frames(audio.shape[-1], center)
        if self.sample_rate != SAMPLE_RATE: audio = self.resample(audio)

        hopsize = seconds_to_samples(self.hopsize, SAMPLE_RATE)
        if center in ['half-hop', 'zero']:
            if center == 'half-hop': padding = int((WINDOW_SIZE - hopsize) / 2)
            else: padding = int(WINDOW_SIZE / 2)

            audio = torch.nn.functional.pad(audio, (padding, padding), mode='reflect')

        if isinstance(hopsize, int) or hopsize.is_integer():
            # Integer hop: frames can be extracted with a single strided unfold.
            hopsize = int(round(hopsize))
            start_idxs = None
        else:
            # Fractional hop: compute per-frame start indices and copy windows one by one.
            start_idxs = torch.round(torch.tensor([hopsize * i for i in range(total_frames + 1)])).int()

        batch_size = total_frames if self.batch_size is None else self.batch_size

        for i in range(0, total_frames, batch_size):
            batch = min(total_frames - i, batch_size)

            if start_idxs is None:
                start = i * hopsize
                end = start + int((batch - 1) * hopsize) + WINDOW_SIZE
                end = min(end, audio.shape[-1])
                batch_audio = audio[:, start:end]

                if end - start < WINDOW_SIZE:
                    padding = WINDOW_SIZE - (end - start)
                    remainder = (end - start) % hopsize

                    if remainder: padding += end - start - hopsize
                    batch_audio = torch.nn.functional.pad(batch_audio, (0, padding))

                frames = torch.nn.functional.unfold(batch_audio[:, None, None], kernel_size=(1, WINDOW_SIZE), stride=(1, hopsize)).permute(2, 0, 1)
            else:
                # Allocate on the audio's device so GPU inputs do not fail on assignment.
                frames = torch.zeros(batch, 1, WINDOW_SIZE, device=audio.device)

                for j in range(batch):
                    start = start_idxs[i + j]
                    end = min(start + WINDOW_SIZE, audio.shape[-1])
                    frames[j, :, : end - start] = audio[:, start:end]

            yield frames

    def viterbi(self, logits):
        if not hasattr(self, 'transition'):
            # Lazily build a transition matrix that only allows jumps of up to
            # 12 bins (60 cents at 5 cents/bin) between frames, favoring small steps.
            xx, yy = np.meshgrid(range(PITCH_BINS), range(PITCH_BINS))
            transition = np.maximum(12 - abs(xx - yy), 0)
            self.transition = transition / transition.sum(axis=1, keepdims=True)

        with torch.no_grad():
            probs = torch.nn.functional.softmax(logits, dim=1)

        bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, self.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
        return bins_to_frequency(bins)

    def postprocess(self, logits):
        with torch.inference_mode():
            minidx = frequency_to_bins(torch.tensor(self.f0_min))
            maxidx = frequency_to_bins(torch.tensor(self.f0_max), torch.ceil)

            # Mask bins outside [f0_min, f0_max] before decoding.
            logits[:, :minidx] = -float('inf')
            logits[:, maxidx:] = -float('inf')

            pitch = self.viterbi(logits)
            periodicity = self.entropy(logits)

            return pitch.T, periodicity.T

    def compute_f0(self, audio, center='half-window'):
        logits = []  # collect logits unconditionally; the original only initialized this when batch_size was set

        for frames in self.preprocess(audio, center):
            if self.onnx:
                inferred = torch.tensor(
                    self.model.run(
                        [self.model.get_outputs()[0].name],
                        {
                            self.model.get_inputs()[0].name: frames.cpu().numpy()
                        }
                    )[0]
                ).detach()
            else:
                with torch.no_grad():
                    inferred = self.model(frames.to(self.device)).detach()

            logits.append(inferred)

        pitch, periodicity = self.postprocess(torch.cat(logits, 0).to(self.device))
        return pitch, periodicity
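
# Minimal usage sketch (an illustration, not part of the original module): load a
# mono 16 kHz clip, wrap it as a (1, samples) tensor, and run pitch estimation.
# The checkpoint path "fcn.pt" and input file "voice.wav" are placeholders.
if __name__ == "__main__":
    audio, sr = librosa.load("voice.wav", sr=16000, mono=True)  # placeholder input file
    audio = torch.from_numpy(audio)[None]  # shape (1, samples)

    fcn = FCN("fcn.pt", hop_length=160, sample_rate=sr)  # placeholder checkpoint path
    pitch, periodicity = fcn.compute_f0(audio)  # per-frame f0 (Hz) and confidence
    print(pitch.shape, periodicity.shape)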