saeki
commited on
Commit
·
200d40d
1
Parent(s):
61192e1
fix
Browse files- lightning_module.py +2 -85
- utils.py +0 -12
lightning_module.py
CHANGED
|
@@ -11,13 +11,6 @@ from model import (
|
|
| 11 |
MultiScaleSpectralLoss,
|
| 12 |
GSTModule,
|
| 13 |
)
|
| 14 |
-
from utils import (
|
| 15 |
-
manual_logging,
|
| 16 |
-
load_vocoder,
|
| 17 |
-
plot_and_save_mels,
|
| 18 |
-
plot_and_save_mels_all,
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
|
| 22 |
class PretrainLightningModule(pl.LightningModule):
|
| 23 |
def __init__(self, config):
|
|
@@ -32,7 +25,7 @@ class PretrainLightningModule(pl.LightningModule):
|
|
| 32 |
self.channelfeats = ChannelFeatureModule(config)
|
| 33 |
|
| 34 |
self.channel = ChannelModule(config)
|
| 35 |
-
self.vocoder =
|
| 36 |
|
| 37 |
self.criteria_a = MultiScaleSpectralLoss(config)
|
| 38 |
if "feature_loss" in config["train"]:
|
|
@@ -154,8 +147,6 @@ class PretrainLightningModule(pl.LightningModule):
|
|
| 154 |
prog_bar=True,
|
| 155 |
logger=True,
|
| 156 |
)
|
| 157 |
-
self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
|
| 158 |
-
self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
|
| 159 |
|
| 160 |
def test_step(self, batch, batch_idx):
|
| 161 |
if self.config["general"]["use_gst"]:
|
|
@@ -224,24 +215,6 @@ class PretrainLightningModule(pl.LightningModule):
|
|
| 224 |
sample_rate=self.config["preprocess"]["sampling_rate"],
|
| 225 |
channels_first=True,
|
| 226 |
)
|
| 227 |
-
plot_and_save_mels(
|
| 228 |
-
out[key][0, ...].cpu(),
|
| 229 |
-
mel_dir / "{}-{}.png".format(idx, key),
|
| 230 |
-
self.config,
|
| 231 |
-
)
|
| 232 |
-
plot_and_save_mels_all(
|
| 233 |
-
out,
|
| 234 |
-
[
|
| 235 |
-
"reconstructed",
|
| 236 |
-
"remastered",
|
| 237 |
-
"channeled",
|
| 238 |
-
"input",
|
| 239 |
-
"input_recons",
|
| 240 |
-
"groundtruth",
|
| 241 |
-
],
|
| 242 |
-
mel_dir / "{}-all.png".format(idx),
|
| 243 |
-
self.config,
|
| 244 |
-
)
|
| 245 |
|
| 246 |
def configure_optimizers(self):
|
| 247 |
optimizer = torch.optim.Adam(
|
|
@@ -257,21 +230,6 @@ class PretrainLightningModule(pl.LightningModule):
|
|
| 257 |
}
|
| 258 |
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
|
| 259 |
|
| 260 |
-
def tflogger(self, logger_dict, data_type):
|
| 261 |
-
for lg in self.logger.experiment:
|
| 262 |
-
if type(lg).__name__ == "SummaryWriter":
|
| 263 |
-
tensorboard = lg
|
| 264 |
-
for key in logger_dict.keys():
|
| 265 |
-
manual_logging(
|
| 266 |
-
logger=tensorboard,
|
| 267 |
-
item=logger_dict[key],
|
| 268 |
-
idx=0,
|
| 269 |
-
tag=key,
|
| 270 |
-
global_step=self.global_step,
|
| 271 |
-
data_type=data_type,
|
| 272 |
-
config=self.config,
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
|
| 276 |
class SSLBaseModule(pl.LightningModule):
|
| 277 |
def __init__(self, config):
|
|
@@ -299,7 +257,7 @@ class SSLBaseModule(pl.LightningModule):
|
|
| 299 |
pre_model.channelfeats.state_dict(), strict=False
|
| 300 |
)
|
| 301 |
|
| 302 |
-
self.vocoder =
|
| 303 |
self.criteria = self.get_loss_function(config)
|
| 304 |
|
| 305 |
def training_step(self, batch, batch_idx):
|
|
@@ -405,32 +363,6 @@ class SSLBaseModule(pl.LightningModule):
|
|
| 405 |
sample_rate=self.config["preprocess"]["sampling_rate"],
|
| 406 |
channels_first=True,
|
| 407 |
)
|
| 408 |
-
plot_and_save_mels(
|
| 409 |
-
out[key][0, ...].cpu(),
|
| 410 |
-
mel_dir / "{}-{}.png".format(idx, key),
|
| 411 |
-
self.config,
|
| 412 |
-
)
|
| 413 |
-
plot_and_save_mels_all(
|
| 414 |
-
out,
|
| 415 |
-
plot_keys,
|
| 416 |
-
mel_dir / "{}-all.png".format(idx),
|
| 417 |
-
self.config,
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
-
def tflogger(self, logger_dict, data_type):
|
| 421 |
-
for lg in self.logger.experiment:
|
| 422 |
-
if type(lg).__name__ == "SummaryWriter":
|
| 423 |
-
tensorboard = lg
|
| 424 |
-
for key in logger_dict.keys():
|
| 425 |
-
manual_logging(
|
| 426 |
-
logger=tensorboard,
|
| 427 |
-
item=logger_dict[key],
|
| 428 |
-
idx=0,
|
| 429 |
-
tag=key,
|
| 430 |
-
global_step=self.global_step,
|
| 431 |
-
data_type=data_type,
|
| 432 |
-
config=self.config,
|
| 433 |
-
)
|
| 434 |
|
| 435 |
|
| 436 |
class SSLStepLightningModule(SSLBaseModule):
|
|
@@ -511,8 +443,6 @@ class SSLStepLightningModule(SSLBaseModule):
|
|
| 511 |
prog_bar=True,
|
| 512 |
logger=True,
|
| 513 |
)
|
| 514 |
-
self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
|
| 515 |
-
self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
|
| 516 |
|
| 517 |
def optimizer_step(
|
| 518 |
self,
|
|
@@ -754,8 +684,6 @@ class SSLDualLightningModule(SSLBaseModule):
|
|
| 754 |
prog_bar=True,
|
| 755 |
logger=True,
|
| 756 |
)
|
| 757 |
-
self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
|
| 758 |
-
self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
|
| 759 |
|
| 760 |
def test_step(self, batch, batch_idx):
|
| 761 |
if self.config["general"]["use_gst"]:
|
|
@@ -833,17 +761,6 @@ class SSLDualLightningModule(SSLBaseModule):
|
|
| 833 |
sample_rate=self.config["preprocess"]["sampling_rate"],
|
| 834 |
channels_first=True,
|
| 835 |
)
|
| 836 |
-
plot_and_save_mels(
|
| 837 |
-
out[key][0, ...].cpu(),
|
| 838 |
-
mel_dir / "{}-{}.png".format(idx, key),
|
| 839 |
-
self.config,
|
| 840 |
-
)
|
| 841 |
-
plot_and_save_mels_all(
|
| 842 |
-
out,
|
| 843 |
-
plot_keys,
|
| 844 |
-
mel_dir / "{}-all.png".format(idx),
|
| 845 |
-
self.config,
|
| 846 |
-
)
|
| 847 |
|
| 848 |
def configure_optimizers(self):
|
| 849 |
optimizer = torch.optim.Adam(
|
|
|
|
| 11 |
MultiScaleSpectralLoss,
|
| 12 |
GSTModule,
|
| 13 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class PretrainLightningModule(pl.LightningModule):
|
| 16 |
def __init__(self, config):
|
|
|
|
| 25 |
self.channelfeats = ChannelFeatureModule(config)
|
| 26 |
|
| 27 |
self.channel = ChannelModule(config)
|
| 28 |
+
self.vocoder = None
|
| 29 |
|
| 30 |
self.criteria_a = MultiScaleSpectralLoss(config)
|
| 31 |
if "feature_loss" in config["train"]:
|
|
|
|
| 147 |
prog_bar=True,
|
| 148 |
logger=True,
|
| 149 |
)
|
|
|
|
|
|
|
| 150 |
|
| 151 |
def test_step(self, batch, batch_idx):
|
| 152 |
if self.config["general"]["use_gst"]:
|
|
|
|
| 215 |
sample_rate=self.config["preprocess"]["sampling_rate"],
|
| 216 |
channels_first=True,
|
| 217 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
def configure_optimizers(self):
|
| 220 |
optimizer = torch.optim.Adam(
|
|
|
|
| 230 |
}
|
| 231 |
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
class SSLBaseModule(pl.LightningModule):
|
| 235 |
def __init__(self, config):
|
|
|
|
| 257 |
pre_model.channelfeats.state_dict(), strict=False
|
| 258 |
)
|
| 259 |
|
| 260 |
+
self.vocoder = None
|
| 261 |
self.criteria = self.get_loss_function(config)
|
| 262 |
|
| 263 |
def training_step(self, batch, batch_idx):
|
|
|
|
| 363 |
sample_rate=self.config["preprocess"]["sampling_rate"],
|
| 364 |
channels_first=True,
|
| 365 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
|
| 368 |
class SSLStepLightningModule(SSLBaseModule):
|
|
|
|
| 443 |
prog_bar=True,
|
| 444 |
logger=True,
|
| 445 |
)
|
|
|
|
|
|
|
| 446 |
|
| 447 |
def optimizer_step(
|
| 448 |
self,
|
|
|
|
| 684 |
prog_bar=True,
|
| 685 |
logger=True,
|
| 686 |
)
|
|
|
|
|
|
|
| 687 |
|
| 688 |
def test_step(self, batch, batch_idx):
|
| 689 |
if self.config["general"]["use_gst"]:
|
|
|
|
| 761 |
sample_rate=self.config["preprocess"]["sampling_rate"],
|
| 762 |
channels_first=True,
|
| 763 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
| 765 |
def configure_optimizers(self):
|
| 766 |
optimizer = torch.optim.Adam(
|
utils.py
CHANGED
|
@@ -3,18 +3,6 @@ import json
|
|
| 3 |
import torch
|
| 4 |
import torchaudio
|
| 5 |
|
| 6 |
-
def load_vocoder(config):
|
| 7 |
-
with open(
|
| 8 |
-
"hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
|
| 9 |
-
) as f:
|
| 10 |
-
config_hifigan = hifigan.AttrDict(json.load(f))
|
| 11 |
-
vocoder = hifigan.Generator(config_hifigan)
|
| 12 |
-
vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
|
| 13 |
-
vocoder.remove_weight_norm()
|
| 14 |
-
for param in vocoder.parameters():
|
| 15 |
-
param.requires_grad = False
|
| 16 |
-
return vocoder
|
| 17 |
-
|
| 18 |
def configure_args(config, args):
|
| 19 |
for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
|
| 20 |
if getattr(args, key) != None:
|
|
|
|
| 3 |
import torch
|
| 4 |
import torchaudio
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def configure_args(config, args):
|
| 7 |
for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
|
| 8 |
if getattr(args, key) != None:
|