|
import os
|
|
import sys
|
|
import glob
|
|
import torch
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import matplotlib.pyplot as plt
|
|
|
|
from collections import OrderedDict
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from main.app.variables import translations
|
|
|
|
# Set to True by plot_spectrogram_to_numpy once it has switched matplotlib to
# the "Agg" backend, so the backend switch happens only on the first call.
MATPLOTLIB_FLAG = False
|
|
|
|
def replace_keys_in_dict(d, old_key_part, new_key_part):
    """Return a copy of ``d`` with ``old_key_part`` swapped out of string keys.

    Every string key has ``old_key_part`` replaced by ``new_key_part``;
    non-string keys are kept as they are. Values that are dictionaries are
    processed recursively, all other values are carried over unchanged.
    The result is an ``OrderedDict`` when ``d`` is one, otherwise a plain
    ``dict``. The input mapping itself is not modified.
    """
    if isinstance(d, OrderedDict):
        renamed = OrderedDict()
    else:
        renamed = {}

    for name, item in d.items():
        if isinstance(name, str):
            name = name.replace(old_key_part, new_key_part)
        if isinstance(item, dict):
            item = replace_keys_in_dict(item, old_key_part, new_key_part)
        renamed[name] = item

    return renamed
|
|
|
|
def load_checkpoint(logger, checkpoint_path, model, optimizer=None, load_opt=1):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        logger: Logger used to report the loaded checkpoint at debug level.
        checkpoint_path: Path of the checkpoint file to read.
        model: Model to load weights into. When it has a ``module``
            attribute, the wrapped ``model.module`` is used instead.
        optimizer: Optional optimizer whose state should be restored.
        load_opt: The optimizer state is restored only when this equals 1.

    Returns:
        A tuple ``(model, optimizer, learning_rate, iteration)`` where
        ``learning_rate`` falls back to 0 when the checkpoint has none.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint lacks the "model" or "iteration" entry.
    """
    # NOTE: kept as an assert so callers keep seeing AssertionError; be aware
    # that the check disappears when Python runs with -O.
    assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)

    # Checkpoints are written with ".weight_v" / ".weight_g" key suffixes (see
    # save_checkpoint); map them back to the ".parametrizations.weight.*"
    # names before matching them against the model's state dict.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    checkpoint_dict = replace_keys_in_dict(checkpoint_dict, ".weight_v", ".parametrizations.weight.original1")
    checkpoint_dict = replace_keys_in_dict(checkpoint_dict, ".weight_g", ".parametrizations.weight.original0")

    # Resolve the wrapped module once instead of repeating the hasattr check.
    target = model.module if hasattr(model, "module") else model

    # Keep the model's current tensor for every key the checkpoint lacks.
    saved_weights = checkpoint_dict["model"]
    new_state_dict = {k: saved_weights.get(k, v) for k, v in target.state_dict().items()}
    target.load_state_dict(new_state_dict, strict=False)

    if optimizer and load_opt == 1:
        optimizer_state = checkpoint_dict.get("optimizer")
        # Loading an empty state dict makes the optimizer raise, so a
        # checkpoint without optimizer state simply leaves it untouched.
        if optimizer_state:
            optimizer.load_state_dict(optimizer_state)

    logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))

    return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])
|
|
|
|
def save_checkpoint(logger, model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model and optimizer state to ``checkpoint_path``.

    The saved dictionary holds the entries "model", "iteration", "optimizer"
    and "learning_rate". Before saving, string keys containing
    ".parametrizations.weight.original1" / ".parametrizations.weight.original0"
    are renamed to use ".weight_v" / ".weight_g" instead. When ``model`` has
    a ``module`` attribute, the state of ``model.module`` is saved.
    A message is logged at info level once the file has been written.
    """
    if hasattr(model, "module"):
        weights = model.module.state_dict()
    else:
        weights = model.state_dict()

    payload = {
        "model": weights,
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    # The renaming is applied to the whole payload, not only the weights.
    payload = replace_keys_in_dict(payload, ".parametrizations.weight.original1", ".weight_v")
    payload = replace_keys_in_dict(payload, ".parametrizations.weight.original0", ".weight_g")

    torch.save(payload, checkpoint_path)

    message = translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration)
    logger.info(message)
|
|
|
|
def summarize(writer, global_step, scalars=None, histograms=None, images=None, audios=None, audio_sample_rate=22050):
    """Forward a batch of summaries to ``writer`` for one training step.

    Args:
        writer: Object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio`` methods.
        global_step: Step value passed along with every entry.
        scalars: Mapping of tag to scalar value, or ``None`` for none.
        histograms: Mapping of tag to histogram values, or ``None``.
        images: Mapping of tag to image; passed with ``dataformats="HWC"``.
        audios: Mapping of tag to audio data, or ``None``.
        audio_sample_rate: Sample rate passed with every audio entry.
    """
    # None defaults replace the former shared mutable ``{}`` defaults.
    for tag, value in (scalars or {}).items():
        writer.add_scalar(tag, value, global_step)

    for tag, value in (histograms or {}).items():
        writer.add_histogram(tag, value, global_step)

    for tag, value in (images or {}).items():
        writer.add_image(tag, value, global_step, dataformats="HWC")

    for tag, value in (audios or {}).items():
        writer.add_audio(tag, value, global_step, audio_sample_rate)
|
|
|
|
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint with the highest number in its name.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (despite the name, not a regular expression)
            joined onto ``dir_path`` to select candidate files.

    Returns:
        The path whose file name contains the largest number, formed from
        all digits of the file name, or ``None`` when nothing matches.
    """
    def numeric_key(path):
        # Only the file name counts: digits in directory names would be
        # glued in front of the step number and distort the ordering.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        # Names without any digit sort before every numbered checkpoint
        # instead of failing on int("").
        return int(digits) if digits else -1

    checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=numeric_key)
    return checkpoints[-1] if checkpoints else None
|
|
|
|
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` with matplotlib and return it as an RGB array.

    On the first call matplotlib is switched to the "Agg" backend and
    ``MATPLOTLIB_FLAG`` is set so the switch is not repeated.

    Args:
        spectrogram: 2-D data accepted by ``Axes.imshow``.

    Returns:
        A ``uint8`` array reshaped to (height, width, 3), where height and
        width come from the figure canvas.
    """
    global MATPLOTLIB_FLAG

    if not MATPLOTLIB_FLAG:
        plt.switch_backend("Agg")
        MATPLOTLIB_FLAG = True

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        image = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(image, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()
        fig.canvas.draw()

        try:
            # Preferred path: read the RGBA buffer and drop the alpha channel.
            data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
        except Exception:
            # Best-effort fallback for canvases where the RGBA path fails.
            # frombuffer replaces the deprecated binary mode of np.fromstring;
            # copy() keeps the result writable like the old call did.
            data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    finally:
        # Always release the figure, even if rendering raised.
        plt.close(fig)

    return data
|
|
|
|
def load_wav_to_torch(full_path):
    """Read an audio file and return ``(samples, sample_rate)``.

    The file is read through ``soundfile`` with a float32 dtype and the
    samples are wrapped in a ``torch.FloatTensor``; the sample rate is
    returned as reported by ``soundfile``.
    """
    audio, rate = sf.read(full_path, dtype=np.float32)
    samples = audio.astype(np.float32)
    tensor = torch.FloatTensor(samples)
    return tensor, rate
|
|
|
|
def load_filepaths_and_text(filename, split="|"):
    """Parse a UTF-8 text file into a list of field lists.

    Each line is stripped of surrounding whitespace and then split on
    ``split``; the resulting lists are returned in file order. Blank lines
    yield a list holding one empty string.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split)
            rows.append(fields)
    return rows
|
|
|
|
class HParams:
    """Hyperparameter container usable both as an object and like a dict.

    Keyword arguments are stored in the instance ``__dict__``, so each entry
    is reachable as an attribute and through ``[]``. Dictionary values are
    converted recursively into nested ``HParams`` instances.
    """

    def __init__(self, **kwargs):
        """Store every keyword argument, wrapping dict values in HParams."""
        for name, value in kwargs.items():
            if isinstance(value, dict):
                value = HParams(**value)
            self[name] = value

    def keys(self):
        """Return a view of the stored parameter names."""
        return vars(self).keys()

    def items(self):
        """Return a view of the stored (name, value) pairs."""
        return vars(self).items()

    def values(self):
        """Return a view of the stored parameter values."""
        return vars(self).values()

    def __len__(self):
        """Return the number of stored parameters."""
        return len(vars(self))

    def __getitem__(self, key):
        """Return the value stored under ``key``; raise KeyError if absent."""
        return vars(self)[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` in the instance dictionary."""
        vars(self)[key] = value

    def __contains__(self, key):
        """Return True when ``key`` is a stored parameter name."""
        return key in vars(self)

    def __repr__(self):
        """Return the repr of the underlying instance dictionary."""
        return repr(vars(self))