Blane187's picture
Upload folder using huggingface_hub
7ff2ba3 verified
raw
history blame
5.34 kB
import numpy as np
import torch
import hashlib
import pathlib
from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
from configs import CPUConfig, singleton_variable
from rvc.synthesizer import get_synthesizer
from .pipeline import Pipeline
from .utils import load_hubert
class TorchSeedContext:
    """Context manager that runs its body under a fixed torch seed.

    On entry the current RNG state (as returned by ``torch.get_rng_state``)
    is saved and ``torch.manual_seed(seed)`` is applied. On exit the saved
    state is restored, so code after the ``with`` block continues from the
    random stream it had before.

    Attributes:
        seed: The seed applied while the context is active.
        state: The RNG state captured on entry; ``None`` until then.
    """

    def __init__(self, seed):
        self.state = None  # populated by __enter__
        self.seed = seed

    def __enter__(self):
        # Snapshot before reseeding so that __exit__ can undo the reseed.
        self.state = torch.get_rng_state()
        torch.manual_seed(self.seed)

    def __exit__(self, type, value, traceback):
        # Implicitly returns None, so an exception raised in the body
        # still propagates after the state has been restored.
        torch.set_rng_state(self.state)
# Number of big-endian int16 slots in a wave hash; the decoded hash is
# therefore half_hash_len * 2 bytes long (see wave_hash / hash_similarity).
half_hash_len = 512
# Scale of the exponential used both for soft-limiting slot values
# (_cut_u16) and for turning a spectral distance into a score
# (hash_similarity).
expand_factor = 65536 * 8
@singleton_variable
def original_audio_storage():
    """Load the ``lgdsng.npz`` archive located next to this module.

    The other ``original_*`` accessors read their arrays from the object
    returned here.
    NOTE(review): ``singleton_variable`` comes from ``configs``; presumably
    it caches the first result so the file is read once - confirm.
    """
    module_dir = pathlib.Path(__file__).parent
    return np.load(module_dir / "lgdsng.npz")
@singleton_variable
def original_audio():
    """Return the ``"a"`` entry of the reference archive.

    ``model_hash`` feeds this array to the conversion pipeline as its input
    audio.
    """
    storage = original_audio_storage()
    return storage["a"]
@singleton_variable
def original_audio_time_minus():
    """Return the ``"t"`` entry of the reference archive.

    ``wave_hash`` adds this array to the normalised time-domain signal.
    """
    storage = original_audio_storage()
    return storage["t"]
@singleton_variable
def original_audio_freq_minus():
    """Return the ``"f"`` entry of the reference archive.

    ``wave_hash`` adds this array to the FFT of the normalised signal.
    """
    storage = original_audio_storage()
    return storage["f"]
@singleton_variable
def original_rmvpe_f0():
    """Return the ``("pitch", "pitchf")`` entries of the reference archive.

    The pair is returned as a 2-tuple and passed as-is to the conversion
    pipeline by ``model_hash``.
    """
    storage = original_audio_storage()
    pitch, pitchf = storage["pitch"], storage["pitchf"]
    return pitch, pitchf
def _cut_u16(n):
    """Soft-limit a scalar so it can be stored in a signed 16-bit slot.

    Values in ``[-16384, 16384]`` are returned unchanged. Outside that
    range the excess is compressed exponentially (scale ``expand_factor``),
    so the result approaches, but never passes, the 16-bit limits.

    Args:
        n: A real scalar (Python or NumPy number).

    Returns:
        ``n`` itself when it is inside the linear range, otherwise a float
        in ``[-32768, 32767]``.
    """
    if n > 16384:
        n = 16384 + 16384 * (1 - np.exp((16384 - n) / expand_factor))
        # For sufficiently large inputs exp() underflows and the expression
        # evaluates to exactly 32768.0, one past the int16 maximum; casting
        # that with np.int16 would wrap. Cap it at the largest valid value.
        # (Anything in (32767, 32768) truncates to 32767 anyway.)
        n = min(n, 32767.0)
    elif n < -16384:
        # The negative branch bottoms out at -32768.0, which int16 can hold.
        n = -16384 - 16384 * (1 - np.exp((n + 16384) / expand_factor))
    return n
def wave_hash(time_field):
    """Fingerprint a 48000-sample waveform as an encoded string.

    WARNING: ``time_field`` is modified in place (it is peak-normalised and
    then offset by the reference time-domain array). Pass a copy if the
    caller still needs the original samples.

    The signal is split into ``half_hash_len // 4`` consecutive chunks.
    The hash holds ``half_hash_len`` big-endian int16 slots:

    * first half: for each chunk of the offset spectrum, the real and
      imaginary parts of its mean, interleaved;
    * second half: for each chunk of the offset time signal, the sum of its
      first half and the sum of its second half, interleaved.

    Every slot is scaled by 32768, rounded and soft-limited by ``_cut_u16``.

    Args:
        time_field: 1-D array of exactly 48000 samples.
            NOTE(review): the in-place divide presumably needs a floating
            dtype - confirm against the pipeline's output.

    Returns:
        The slot bytes encoded with ``encode_to_string``.

    Raises:
        Exception: if the input (or its FFT) is not 48000 samples long.
        ValueError: if the input is entirely zero, since it cannot be
            peak-normalised.
    """
    # Validate before touching the caller's buffer: a rejected input must
    # not be left half-normalised, and an empty one must get this error
    # rather than numpy's "zero-size array" reduction error.
    if len(time_field) != 48000:
        raise Exception("time not hashable")
    peak = np.abs(time_field).max()
    if peak == 0:
        # Dividing by a zero peak would fill the buffer with NaN and fail
        # later with an unrelated-looking error.
        raise ValueError("silent audio is not hashable")
    np.divide(time_field, peak, time_field)
    freq_field = fft(time_field)
    if len(freq_field) != 48000:
        raise Exception("freq not hashable")
    # Offset both domains by the stored reference arrays (in place).
    np.add(time_field, original_audio_time_minus(), out=time_field)
    np.add(freq_field, original_audio_freq_minus(), out=freq_field)
    slots = np.zeros(half_hash_len // 2 * 2, dtype=">i2")
    d = 375 * 512 // half_hash_len  # samples (and FFT bins) per chunk
    time_base = half_hash_len // 2  # first slot of the time-domain half
    for i in range(half_hash_len // 4):
        a = i * 2
        b = a + 1
        x = a + time_base
        y = x + 1
        s = np.average(freq_field[i * d : (i + 1) * d])
        slots[a] = np.int16(_cut_u16(round(32768 * np.real(s))))
        slots[b] = np.int16(_cut_u16(round(32768 * np.imag(s))))
        slots[x] = np.int16(
            _cut_u16(round(32768 * np.sum(time_field[i * d : i * d + d // 2])))
        )
        slots[y] = np.int16(
            _cut_u16(round(32768 * np.sum(time_field[i * d + d // 2 : (i + 1) * d])))
        )
    return encode_to_string(slots.tobytes())
def model_hash(config, tgt_sr, net_g, if_f0, version):
    """Hash a synthesizer by the audio it produces for the reference input.

    The reference audio and reference f0 are run through a ``Pipeline``
    built for ``tgt_sr`` using ``net_g``. The output is brought to exactly
    48000 samples and fingerprinted with ``wave_hash``.

    Args:
        config: Object providing ``device`` and ``is_half``; also passed to
            ``Pipeline``.
        tgt_sr: Target sample rate, passed to ``Pipeline`` and the pipeline
            call.
        net_g: The synthesizer network to fingerprint.
        if_f0: Truthy selects ``2`` and falsy ``0`` for the corresponding
            pipeline argument.
        version: Model version value forwarded to the pipeline.

    Returns:
        The encoded hash string produced by ``wave_hash``.
    """
    pipeline = Pipeline(tgt_sr, config)
    audio = original_audio()
    hbt = load_hubert(config.device, config.is_half)
    audio_opt = pipeline.pipeline(
        hbt,
        net_g,
        0,
        audio,
        [0, 0, 0],
        6,
        original_rmvpe_f0(),
        "",
        0,
        2 if if_f0 else 0,
        3,
        tgt_sr,
        16000,
        0.25,
        version,
        0.33,
    )
    del hbt
    target_len = 48000  # wave_hash only accepts exactly this many samples
    opt_len = len(audio_opt)
    if opt_len < target_len:
        # Too short: pad at the front (np.pad's default constant padding).
        audio_opt = np.pad(audio_opt, (target_len - opt_len, 0))
    elif opt_len > target_len:
        # Too long: keep a centred window of exactly target_len samples.
        # Slicing start:start+target_len, instead of trimming the same
        # amount from both ends, also handles an odd excess; symmetric
        # trimming would remove one sample too many and leave 47999.
        start = (opt_len - target_len) // 2
        audio_opt = audio_opt[start : start + target_len]
    h = wave_hash(audio_opt)
    del pipeline, audio_opt
    return h
def model_hash_ckpt(cpt):
    """Build the synthesizer from a loaded checkpoint and hash it.

    The model is constructed and run under ``CPUConfig`` and inside a
    ``TorchSeedContext`` with a fixed seed, so the caller's torch RNG state
    is put back afterwards.

    Args:
        cpt: Loaded checkpoint object accepted by ``get_synthesizer``.

    Returns:
        The encoded hash string produced by ``model_hash``.
    """
    config = CPUConfig()
    with TorchSeedContext(114514):
        # get_synthesizer hands back the network and a checkpoint mapping;
        # the remaining settings are read from that mapping.
        net_g, cpt = get_synthesizer(cpt, config.device)
        sample_rate = cpt["config"][-1]
        uses_f0 = cpt.get("f0", 1)
        model_version = cpt.get("version", "v1")
        # Match the network precision to the configuration.
        net_g = net_g.half() if config.is_half else net_g.float()
        digest = model_hash(config, sample_rate, net_g, uses_f0, model_version)
    del net_g
    return digest
def model_hash_from(path):
    """Load the checkpoint at ``path`` onto the CPU and return its hash.

    Args:
        path: Location of the checkpoint, as accepted by ``torch.load``.

    Returns:
        The encoded hash string produced by ``model_hash_ckpt``.
    """
    # NOTE(review): torch.load unpickles the file, so a checkpoint from an
    # untrusted source can execute arbitrary code - only hash trusted files.
    return model_hash_ckpt(torch.load(path, map_location="cpu"))
def _extend_difference(n, a, b):
    """Clamp ``n`` to ``[a, b]`` and rescale that interval onto ``[0, 1]``.

    Args:
        n: Value to map.
        a: Lower bound; maps to 0.
        b: Upper bound; maps to 1. Must differ from ``a``, otherwise the
            division fails.

    Returns:
        ``(clamped - a) / (b - a)``.
    """
    clamped = a if n < a else (b if n > b else n)
    return (clamped - a) / (b - a)
def hash_similarity(h1: str, h2: str) -> float:
    """Score how alike two encoded wave hashes are.

    The score is the mean of two terms, rounded to 6 decimals:

    * the absolute cosine similarity of the two full int16 slot vectors
      (taken as 1.0 when either vector has zero norm);
    * ``exp(-d / expand_factor)`` clamped to ``[0.5, 1]`` and stretched to
      ``[0, 1]``, where ``d`` is the summed complex distance between
      corresponding (real, imag) pairs in the first half of the slots.
      Pairs where either side is exactly zero are skipped.

    Args:
        h1: First encoded hash.
        h2: Second encoded hash.

    Returns:
        The similarity score. NOTE: despite the annotation, any exception
        raised while decoding or comparing is caught and its message is
        returned as a ``str`` instead (e.g. ``"invalid hash length"``).
    """
    try:
        raw1 = decode_from_string(h1)
        raw2 = decode_from_string(h2)
        expected_bytes = half_hash_len * 2
        if not (len(raw1) == expected_bytes and len(raw2) == expected_bytes):
            raise Exception("invalid hash length")
        h1n = np.frombuffer(raw1, dtype=">i2")
        h2n = np.frombuffer(raw2, dtype=">i2")
        d = 0
        # Walk the interleaved (real, imag) pairs of the spectral half.
        for re_idx in range(0, half_hash_len // 2, 2):
            im_idx = re_idx + 1
            ax = complex(h1n[re_idx], h1n[im_idx])
            bx = complex(h2n[re_idx], h2n[im_idx])
            if abs(ax) == 0 or abs(bx) == 0:
                continue
            d += np.abs(ax - bx)
        frac = np.linalg.norm(h1n) * np.linalg.norm(h2n)
        if frac != 0:
            cosine = np.dot(h1n.astype(np.float32), h2n.astype(np.float32)) / frac
        else:
            cosine = 1.0
        distance = _extend_difference(np.exp(-d / expand_factor), 0.5, 1.0)
        return round((abs(cosine) + distance) / 2, 6)
    except Exception as e:
        # Errors are reported through the return value, not raised.
        return str(e)
def hash_id(h: str) -> str:
    """Derive a short identifier string from an encoded wave hash.

    The id is the concatenation of two encoded parts:

    * the wrapping 64-bit sum of the hash bytes read as unsigned 64-bit
      words, with the last two characters of its encoding dropped;
    * the first 7 bytes of the MD5 digest of the hash bytes.

    Args:
        h: Encoded hash string (as produced by ``wave_hash``).

    Returns:
        The identifier, or the literal string ``"invalid hash length"``
        when the decoded hash is not ``half_hash_len * 2`` bytes long.
    """
    d = decode_from_string(h)
    if len(d) != half_hash_len * 2:
        return "invalid hash length"
    # Pin the word layout to little-endian for both reading and writing.
    # A native-endian np.uint64 would give a different id on a big-endian
    # host; on little-endian hosts this is byte-identical to native.
    word_sum = np.frombuffer(d, dtype="<u8").sum(keepdims=True).astype("<u8")
    checksum_part = encode_to_string(word_sum.tobytes())[:-2]
    digest_part = encode_to_string(hashlib.md5(d).digest()[:7])
    return checksum_part + digest_part