File size: 4,864 Bytes
1e4a2ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
import sys
import glob
import torch
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from collections import OrderedDict
sys.path.append(os.getcwd())
from main.app.variables import translations
# Set to True by plot_spectrogram_to_numpy once it has switched matplotlib to the "Agg" backend.
MATPLOTLIB_FLAG = False
def replace_keys_in_dict(d, old_key_part, new_key_part):
    """Return a copy of ``d`` with a substring replaced in every string key.

    The replacement is applied recursively to any value that is itself a
    dict. Non-string keys and non-dict values are carried over unchanged.
    The input mapping is not modified.

    Args:
        d: Mapping to copy. An ``OrderedDict`` input yields an
            ``OrderedDict`` result; anything else yields a plain ``dict``.
        old_key_part: Substring to look for in string keys.
        new_key_part: Text that replaces each occurrence of ``old_key_part``.

    Returns:
        A new mapping with the renamed keys.
    """
    if isinstance(d, OrderedDict):
        renamed = OrderedDict()
    else:
        renamed = {}

    for name, item in d.items():
        # Only string keys can contain the substring; leave other keys as-is.
        new_name = name.replace(old_key_part, new_key_part) if isinstance(name, str) else name
        if isinstance(item, dict):
            item = replace_keys_in_dict(item, old_key_part, new_key_part)
        renamed[new_name] = item

    return renamed
def load_checkpoint(logger, checkpoint_path, model, optimizer=None, load_opt=1):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is read onto the CPU with ``weights_only=True``. Keys
    containing ``.weight_v`` / ``.weight_g`` are renamed to
    ``.parametrizations.weight.original1`` / ``.original0`` before loading.
    Parameters the checkpoint does not contain keep the model's current
    values; checkpoint entries the model does not have are ignored.

    Args:
        logger: Logger used to report the loaded checkpoint.
        checkpoint_path: Path of the checkpoint file.
        model: Model to load into. If it has a ``module`` attribute, the
            state is loaded into ``model.module`` instead.
        optimizer: Optional optimizer whose state should be restored.
        load_opt: Optimizer state is restored only when this equals 1.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)`` where
        ``learning_rate`` defaults to 0 if the checkpoint does not store it.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint lacks the "model" or "iteration" entry.
    """
    # Kept as an assert so callers that expect AssertionError keep working.
    assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)

    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    checkpoint_dict = replace_keys_in_dict(checkpoint_dict, ".weight_v", ".parametrizations.weight.original1")
    checkpoint_dict = replace_keys_in_dict(checkpoint_dict, ".weight_g", ".parametrizations.weight.original0")

    # Resolve the wrapped module once instead of re-checking for every call.
    target = model.module if hasattr(model, "module") else model
    saved_weights = checkpoint_dict["model"]
    new_state_dict = {k: saved_weights.get(k, v) for k, v in target.state_dict().items()}
    target.load_state_dict(new_state_dict, strict=False)

    # A checkpoint without optimizer state is skipped rather than handing an
    # empty dict to load_state_dict, which cannot be loaded.
    optimizer_state = checkpoint_dict.get("optimizer")
    if optimizer and load_opt == 1 and optimizer_state:
        optimizer.load_state_dict(optimizer_state)

    iteration = checkpoint_dict["iteration"]
    # NOTE(review): the translation key is "save_checkpoint" although this
    # function loads a checkpoint; confirm that this is the intended message.
    logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=iteration))
    return (model, optimizer, checkpoint_dict.get("learning_rate", 0), iteration)
def save_checkpoint(logger, model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model and optimizer state to ``checkpoint_path``.

    Before saving, keys containing ``.parametrizations.weight.original1`` /
    ``.original0`` are renamed to ``.weight_v`` / ``.weight_g`` throughout
    the saved dictionary.

    Args:
        logger: Logger used to report the saved checkpoint.
        model: Model to save. If it has a ``module`` attribute, the state of
            ``model.module`` is saved instead.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        learning_rate: Value stored under "learning_rate".
        iteration: Value stored under "iteration".
        checkpoint_path: Destination file path.
    """
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    payload = {
        "model": state_dict,
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    payload = replace_keys_in_dict(payload, ".parametrizations.weight.original1", ".weight_v")
    payload = replace_keys_in_dict(payload, ".parametrizations.weight.original0", ".weight_g")

    torch.save(payload, checkpoint_path)
    logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
def summarize(writer, global_step, scalars=None, histograms=None, images=None, audios=None, audio_sample_rate=22050):
    """Forward scalars, histograms, images and audio clips to a summary writer.

    Args:
        writer: Object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: Step value passed along with every entry.
        scalars: Optional mapping of tag to scalar value.
        histograms: Optional mapping of tag to histogram values.
        images: Optional mapping of tag to image; logged with
            ``dataformats="HWC"``.
        audios: Optional mapping of tag to audio data.
        audio_sample_rate: Sample rate passed to ``add_audio``.
    """
    # None defaults avoid sharing mutable dict defaults between calls.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sample_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint with the highest number in its file name.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (not a regular expression, despite the name)
            joined onto ``dir_path``.

    Returns:
        The path whose file name yields the largest number when its decimal
        digits are concatenated, or ``None`` if nothing matches.
    """

    def step_of(path):
        # Use only the file name so digits in the directory path cannot
        # affect the ordering, and rank digit-less names lowest instead of
        # failing on int("").
        digits = "".join(ch for ch in os.path.basename(path) if ch.isdecimal())
        return int(digits) if digits else -1

    checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=step_of)
    return checkpoints[-1] if checkpoints else None
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with matplotlib and return it as an RGB array.

    On the first call the matplotlib backend is switched to "Agg" and the
    module-level ``MATPLOTLIB_FLAG`` is set so the switch happens only once.

    Args:
        spectrogram: 2-D data accepted by ``Axes.imshow``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` holding the
        rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        plt.switch_backend("Agg")
        MATPLOTLIB_FLAG = True

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
        plt.xlabel("Frames")
        plt.ylabel("Channels")
        plt.tight_layout()
        fig.canvas.draw()

        # get_width_height() returns (width, height); arrays are (height, width, channels).
        height_width = fig.canvas.get_width_height()[::-1]
        try:
            # Preferred path: read the RGBA buffer and drop the alpha channel.
            data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8).reshape(height_width + (4,))[:, :, :3]
        except Exception:
            # Fallback for canvases without an RGBA renderer buffer.
            # frombuffer replaces the deprecated binary mode of np.fromstring;
            # copy() keeps the result writable, as before.
            data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(height_width + (3,)).copy()
    finally:
        # Always release the figure, even if plotting or rendering failed.
        plt.close(fig)
    return data
def load_wav_to_torch(full_path):
    """Read an audio file and return its samples as a float tensor.

    Args:
        full_path: Path of the audio file, passed to ``soundfile.read``.

    Returns:
        Tuple ``(samples, sample_rate)`` where ``samples`` is a
        ``torch.FloatTensor`` built from the float32 sample data.
    """
    samples, sample_rate = sf.read(full_path, dtype=np.float32)
    as_float32 = samples.astype(np.float32)
    tensor = torch.FloatTensor(as_float32)
    return tensor, sample_rate
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and split each line into fields.

    Args:
        filename: Path of the file to read.
        split: Separator used to split each stripped line.

    Returns:
        A list with one entry per line, each a list of the line's fields.
        Leading and trailing whitespace is stripped from the line before
        splitting, so a blank line yields ``[""]``.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split)
            rows.append(fields)
    return rows
class HParams:
    """Hyper-parameter container with both attribute and dict-style access.

    Keyword arguments become instance attributes. Values that are dicts are
    converted recursively into nested ``HParams`` instances. All entries
    live in the instance ``__dict__``, which backs the mapping-style methods.
    """

    def __init__(self, **kwargs):
        """Store each keyword argument, wrapping dict values in ``HParams``."""
        for k, v in kwargs.items():
            self[k] = HParams(**v) if isinstance(v, dict) else v

    def keys(self):
        """Return a view of the stored parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of the stored ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the stored values."""
        return self.__dict__.values()

    def __len__(self):
        """Return the number of stored parameters."""
        return len(self.__dict__)

    def __getitem__(self, key):
        """Return the value stored under ``key``; raise ``KeyError`` if absent."""
        return self.__dict__[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        self.__dict__[key] = value

    def __contains__(self, key):
        """Return whether ``key`` is a stored parameter name."""
        return key in self.__dict__

    def __repr__(self):
        """Return the repr of the underlying parameter dictionary."""
        return repr(self.__dict__)