from io import BytesIO
import os
from typing import Union, Literal, Optional

import fairseq
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import Resample

from rvc.f0 import PM, Harvest, RMVPE, CRePE, Dio, FCPE
from rvc.synthesizer import load_synthesizer

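
# RVC bundles the real-time voice-conversion pipeline: a HuBERT content
# encoder, an optional faiss feature index for timbre retrieval, a set of
# interchangeable f0 extractors, and the RVC synthesizer (net_g).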
class RVC:
    def __init__(
        self,
        key: Union[int, float],
        formant: Union[int, float],
        pth_path: torch.serialization.FILE_LIKE,
        index_path: str,
        index_rate: Union[int, float],
        n_cpu: int = os.cpu_count(),
        device: str = "cpu",
        use_jit: bool = False,
        is_half: bool = False,
        is_dml: bool = False,
    ) -> None:
        if is_dml:

            def forward_dml(ctx, x, scale):
                ctx.scale = scale
                res = x.clone().detach()
                return res

            fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml

        self.device = device
        self.f0_up_key = key
        self.formant_shift = formant
        self.sr = 16000
        self.window = 160
        self.f0_min = 50
        self.f0_max = 1100
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.n_cpu = n_cpu
        self.use_jit = use_jit
        self.is_half = is_half

        if index_rate > 0:
            self.index = faiss.read_index(index_path)
            self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)

        self.pth_path = pth_path
        self.index_path = index_path
        self.index_rate = index_rate

        self.cache_pitch: torch.Tensor = torch.zeros(
            1024, device=self.device, dtype=torch.long
        )
        self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32)

        self.resample_kernel = {}

        self.f0_methods = {
            "crepe": self._get_f0_crepe,
            "rmvpe": self._get_f0_rmvpe,
            "fcpe": self._get_f0_fcpe,
            "pm": self._get_f0_pm,
            "harvest": self._get_f0_harvest,
            "dio": self._get_f0_dio,
        }

        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            ["assets/hubert/hubert_base.pt"],
            suffix="",
        )
        hubert_model = models[0]
        hubert_model = hubert_model.to(self.device)
        if self.is_half:
            hubert_model = hubert_model.half()
        else:
            hubert_model = hubert_model.float()
        hubert_model.eval()
        self.hubert = hubert_model

        self.net_g: Optional[nn.Module] = None

        def set_default_model():
            self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
            self.tgt_sr = cpt["config"][-1]
            cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
            self.if_f0 = cpt.get("f0", 1)
            self.version = cpt.get("version", "v1")
            if self.is_half:
                self.net_g = self.net_g.half()
            else:
                self.net_g = self.net_g.float()

        def set_jit_model():
            from rvc.jit import get_jit_model
            from rvc.synthesizer import synthesizer_jit_export

            cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export)

            self.tgt_sr = cpt["config"][-1]
            self.if_f0 = cpt.get("f0", 1)
            self.version = cpt.get("version", "v1")
            self.net_g = torch.jit.load(BytesIO(cpt["model"]), map_location=self.device)
            self.net_g.infer = self.net_g.forward
            self.net_g.eval().to(self.device)

        if (
            self.use_jit
            and not is_dml
            and not (self.is_half and "cpu" in str(self.device))
        ):
            set_jit_model()
        else:
            set_default_model()

    def set_key(self, new_key):
        self.f0_up_key = new_key

    def set_formant(self, new_formant):
        self.formant_shift = new_formant

    def set_index_rate(self, new_index_rate):
        if new_index_rate > 0 and self.index_rate <= 0:
            self.index = faiss.read_index(self.index_path)
            self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
        self.index_rate = new_index_rate

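    # infer() runs one streaming block through the pipeline: extract HuBERT
    # content features, optionally blend them with nearest neighbours from the
    # faiss index (weighted by index_rate), update the rolling pitch cache,
    # run the synthesizer, and resample the output when a formant shift
    # changes the effective sample rate.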
    def infer(
        self,
        input_wav: torch.Tensor,
        block_frame_16k: int,
        skip_head: int,
        return_length: int,
        f0method: Union[tuple, str],
        inp_f0: Optional[np.ndarray] = None,
        protect: float = 1.0,
    ) -> np.ndarray:
        with torch.no_grad():
            if self.is_half:
                feats = input_wav.half()
            else:
                feats = input_wav.float()
            feats = feats.to(self.device)
            if feats.dim() == 2:  # collapse a two-channel input to mono
                feats = feats.mean(-1)
            feats = feats.view(1, -1)
            padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)

            inputs = {
                "source": feats,
                "padding_mask": padding_mask,
                "output_layer": 9 if self.version == "v1" else 12,
            }
            logits = self.hubert.extract_features(**inputs)
            feats = (
                self.hubert.final_proj(logits[0]) if self.version == "v1" else logits[0]
            )
            feats = torch.cat((feats, feats[:, -1:, :]), 1)
            if protect < 0.5 and self.if_f0 == 1:
                feats0 = feats.clone()

        try:
            if hasattr(self, "index") and self.index_rate > 0:
                npy = feats[0][skip_head // 2 :].cpu().numpy()
                if self.is_half:
                    npy = npy.astype("float32")
                score, ix = self.index.search(npy, k=8)
                if (ix >= 0).all():
                    weight = np.square(1 / score)
                    weight /= weight.sum(axis=1, keepdims=True)
                    npy = np.sum(
                        self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1
                    )
                    if self.is_half:
                        npy = npy.astype("float16")
                    # blend retrieved features with the HuBERT features by index_rate
                    feats[0][skip_head // 2 :] = (
                        torch.from_numpy(npy).unsqueeze(0).to(self.device)
                        * self.index_rate
                        + (1 - self.index_rate) * feats[0][skip_head // 2 :]
                    )
        except Exception:
            # index retrieval is best-effort; fall back to the raw HuBERT features
            pass

        p_len = input_wav.shape[0] // self.window
        factor = pow(2, self.formant_shift / 12)  # formant shift in semitones
        return_length2 = int(np.ceil(return_length * factor))
        cache_pitch = cache_pitchf = None
        pitch = pitchf = None
        if isinstance(f0method, tuple):
            pitch, pitchf = f0method
            pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
            pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
        elif self.if_f0 == 1:
            f0_extractor_frame = block_frame_16k + 800
            if f0method == "rmvpe":
                f0_extractor_frame = (
                    5120 * ((f0_extractor_frame - 1) // 5120 + 1) - self.window
                )
            if inp_f0 is not None:
                pitch, pitchf = self._get_f0_post(
                    inp_f0, self.f0_up_key - self.formant_shift
                )
            else:
                pitch, pitchf = self._get_f0(
                    input_wav[-f0_extractor_frame:],
                    self.f0_up_key - self.formant_shift,
                    method=f0method,
                )
            # roll the pitch cache left by one block and append the new frames
            shift = block_frame_16k // self.window
            self.cache_pitch[:-shift] = self.cache_pitch[shift:].clone()
            self.cache_pitchf[:-shift] = self.cache_pitchf[shift:].clone()
            self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
            self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
            cache_pitch = self.cache_pitch[None, -p_len:]
            cache_pitchf = (
                self.cache_pitchf[None, -p_len:] * return_length2 / return_length
            )

        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        feats = feats[:, :p_len, :]
        if protect < 0.5 and pitch is not None and pitchf is not None:
            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
                0, 2, 1
            )
            feats0 = feats0[:, :p_len, :]
            pitchff = pitchf.clone()
            pitchff[pitchf > 0] = 1
            pitchff[pitchf < 1] = protect
            pitchff = pitchff.unsqueeze(-1)
            feats = feats * pitchff + feats0 * (1 - pitchff)
            feats = feats.to(feats0.dtype)
        p_len = torch.LongTensor([p_len]).to(self.device)
        sid = torch.LongTensor([0]).to(self.device)
        with torch.no_grad():
            infered_audio = (
                self.net_g.infer(
                    feats,
                    p_len,
                    sid,
                    pitch=cache_pitch,
                    pitchf=cache_pitchf,
                    skip_head=skip_head,
                    return_length=return_length,
                    return_length2=return_length2,
                )
                .squeeze(1)
                .float()
            )
            upp_res = int(np.floor(factor * self.tgt_sr // 100))
            if upp_res != self.tgt_sr // 100:
                if upp_res not in self.resample_kernel:
                    self.resample_kernel[upp_res] = Resample(
                        orig_freq=upp_res,
                        new_freq=self.tgt_sr // 100,
                        dtype=torch.float32,
                    ).to(self.device)
                infered_audio = self.resample_kernel[upp_res](
                    infered_audio[:, : return_length * upp_res]
                )
        return infered_audio.squeeze()

    def _get_f0(
        self,
        x: torch.Tensor,
        f0_up_key: Union[int, float],
        filter_radius: Optional[Union[int, float]] = None,
        method: Literal["crepe", "rmvpe", "fcpe", "pm", "harvest", "dio"] = "fcpe",
    ):
        if method not in self.f0_methods:
            raise RuntimeError("Unsupported f0 method: " + method)
        return self.f0_methods[method](x, f0_up_key, filter_radius)

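    # _get_f0_post applies the f0_up_key shift in semitones, then quantizes the
    # pitch curve onto the 1-255 coarse mel scale (1 = unvoiced) consumed by the
    # synthesizer, returning both the coarse and the continuous curves.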
    def _get_f0_post(self, f0, f0_up_key):
        f0 *= pow(2, f0_up_key / 12)
        if not torch.is_tensor(f0):
            f0 = torch.from_numpy(f0)
        f0 = f0.float().to(self.device).squeeze()
        f0_mel = 1127 * torch.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (
            self.f0_mel_max - self.f0_mel_min
        ) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255
        f0_coarse = torch.round(f0_mel).long()
        return f0_coarse, f0

    def _get_f0_pm(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "pm"):
            self.pm = PM(hop_length=160, sampling_rate=16000)
        f0 = self.pm.compute_f0(x.cpu().numpy())
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_harvest(self, x, f0_up_key, filter_radius=3):
        if not hasattr(self, "harvest"):
            self.harvest = Harvest(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
            )
        if filter_radius is None:
            filter_radius = 3
        f0 = self.harvest.compute_f0(x.cpu().numpy(), filter_radius=filter_radius)
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_dio(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "dio"):
            self.dio = Dio(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
            )
        f0 = self.dio.compute_f0(x.cpu().numpy())
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_crepe(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "crepe"):
            self.crepe = CRePE(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
                self.device,
            )
        f0 = self.crepe.compute_f0(x)
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_rmvpe(self, x, f0_up_key, filter_radius=0.03):
        if not hasattr(self, "rmvpe"):
            self.rmvpe = RMVPE(
                "%s/rmvpe.pt" % os.environ["rmvpe_root"],
                is_half=self.is_half,
                device=self.device,
                use_jit=self.use_jit,
            )
        if filter_radius is None:
            filter_radius = 0.03
        return self._get_f0_post(
            self.rmvpe.compute_f0(x, filter_radius=filter_radius),
            f0_up_key,
        )

    def _get_f0_fcpe(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "fcpe"):
            self.fcpe = FCPE(
                160,
                self.f0_min,
                self.f0_max,
                16000,
                self.device,
            )
        f0 = self.fcpe.compute_f0(x)
        return self._get_f0_post(f0, f0_up_key)
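

# Minimal usage sketch, not part of the class. The model and index paths below
# are placeholders, and the constructor also expects assets/hubert/hubert_base.pt
# on disk (plus the rmvpe_root environment variable when f0method="rmvpe").
# Real block sizes normally come from the streaming GUI that drives infer();
# a one-second 16 kHz block is used here only to keep the shapes consistent.
if __name__ == "__main__":
    t = torch.arange(16000, dtype=torch.float32) / 16000.0
    sample_block = 0.1 * torch.sin(2 * np.pi * 220.0 * t)  # 1 s of a 220 Hz tone

    rvc = RVC(
        key=0,                               # pitch shift in semitones
        formant=0,                           # formant shift in semitones
        pth_path="path/to/model.pth",        # placeholder
        index_path="path/to/model.index",    # placeholder
        index_rate=0.0,                      # 0 disables faiss retrieval
        device="cpu",
    )
    out = rvc.infer(
        input_wav=sample_block,
        block_frame_16k=sample_block.shape[0],
        skip_head=0,
        return_length=sample_block.shape[0] // 160,
        f0method="pm",
    )
    print(out.shape)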