import os import sys import torch import librosa import scipy.stats import numpy as np sys.path.append(os.getcwd()) from main.library.predictors.CREPE.model import MODEL CENTS_PER_BIN, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 360, 16000, 1024 class CREPE: def __init__(self, model_path, model_size="full", hop_length=512, batch_size=None, f0_min=50, f0_max=1100, device=None, sample_rate=16000, providers=None, onnx=False, return_periodicity=False): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.hop_length = hop_length self.batch_size = batch_size self.sample_rate = sample_rate self.onnx = onnx self.f0_min = f0_min self.f0_max = f0_max self.return_periodicity = return_periodicity if self.onnx: import onnxruntime as ort sess_options = ort.SessionOptions() sess_options.log_severity_level = 3 self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers) else: model = MODEL(model_size) ckpt = torch.load(model_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt) model.eval() self.model = model.to(device) def bins_to_frequency(self, bins): if str(bins.device).startswith("ocl"): bins = bins.to(torch.float32) cents = CENTS_PER_BIN * bins + 1997.3794084376191 return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200) def frequency_to_bins(self, frequency, quantize_fn=torch.floor): return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int() def viterbi(self, logits): if not hasattr(self, 'transition'): xx, yy = np.meshgrid(range(360), range(360)) transition = np.maximum(12 - abs(xx - yy), 0) self.transition = transition / transition.sum(axis=1, keepdims=True) with torch.no_grad(): probs = torch.nn.functional.softmax(logits, dim=1) bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, self.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device) return bins, self.bins_to_frequency(bins) def preprocess(self, audio, pad=True): hop_length = (self.sample_rate // 100) if self.hop_length is None else self.hop_length if self.sample_rate != SAMPLE_RATE: audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=self.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0) hop_length = int(hop_length * SAMPLE_RATE / self.sample_rate) if pad: total_frames = 1 + int(audio.size(1) // hop_length) audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2)) else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length) batch_size = total_frames if self.batch_size is None else self.batch_size for i in range(0, total_frames, batch_size): frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length)) if self.device.startswith("ocl"): frames = frames.transpose(1, 2).contiguous().reshape(-1, WINDOW_SIZE).to(self.device) else: frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(self.device) frames -= frames.mean(dim=1, keepdim=True) frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True)) yield frames def periodicity(self, probabilities, bins): probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS) periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64)) return periodicity.reshape(probabilities.size(0), probabilities.size(2)) def postprocess(self, probabilities): probabilities = probabilities.detach() probabilities[:, :self.frequency_to_bins(torch.tensor(self.f0_min))] = -float('inf') probabilities[:, self.frequency_to_bins(torch.tensor(self.f0_max), torch.ceil):] = -float('inf') bins, pitch = self.viterbi(probabilities) if not self.return_periodicity: return pitch return pitch, self.periodicity(probabilities, bins) def compute_f0(self, audio, pad=True): results = [] for frames in self.preprocess(audio, pad): if self.onnx: model = torch.tensor( self.model.run( [self.model.get_outputs()[0].name], { self.model.get_inputs()[0].name: frames.cpu().numpy() } )[0].transpose(1, 0)[None] ) else: with torch.no_grad(): model = self.model( frames, embed=False ).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2) result = self.postprocess(model) results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device)) if self.return_periodicity: pitch, periodicity = zip(*results) return torch.cat(pitch, 1), torch.cat(periodicity, 1) return torch.cat(results, 1)