Commit 80a1624
Parent: 6af03de
Added intermediate logging of chain_inference
Files changed:
- cfg/config.yaml +0 -1
- cfg/model/demucs.yaml +2 -0
- remfx/models.py +16 -1
- scripts/chain_inference.py +2 -6
- scripts/test.py +0 -2
- scripts/train.py +0 -3
cfg/config.yaml
CHANGED
@@ -117,7 +117,6 @@ trainer:
   precision: 32 # Precision used for tensors, default `32`
   min_epochs: 0
   max_epochs: -1
-  enable_model_summary: False
   log_every_n_steps: 1 # Logs metrics every N batches
   accumulate_grad_batches: 1
   accelerator: ${accelerator}
cfg/model/demucs.yaml
CHANGED
@@ -13,3 +13,5 @@ model:
   audio_channels: 1
   nfft: 4096
   sample_rate: ${sample_rate}
+  channels: 64
+
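Note: in Demucs-style architectures, `channels` typically sets the width of the first encoder layer (64 is the common default), so this change makes that width explicit in the config. A minimal sketch of how the new key flows through OmegaConf interpolation, with an assumed top-level `sample_rate` of 48000 standing in for the project's real value:

from omegaconf import OmegaConf

# Assumed top-level value standing in for the project's real sample_rate.
base = OmegaConf.create({"sample_rate": 48000})
model_cfg = OmegaConf.create(
    "model:\n"
    "  audio_channels: 1\n"
    "  nfft: 4096\n"
    "  sample_rate: ${sample_rate}\n"
    "  channels: 64\n"
)
cfg = OmegaConf.merge(base, model_cfg)
print(cfg.model.channels)     # 64
print(cfg.model.sample_rate)  # 48000, resolved from the root key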
remfx/models.py
CHANGED
@@ -11,6 +11,7 @@ from umx.openunmix.model import OpenUnmix, Separator
 from remfx.utils import FADLoss, spectrogram
 from remfx.tcn import TCN
 from remfx.utils import causal_crop
+from remfx.callbacks import log_wandb_audio_batch
 from remfx import effects
 import asteroid
 
@@ -42,12 +43,26 @@ class RemFXChainInference(pl.LightningModule):
         ]
         output = []
         with torch.no_grad():
-            for elem, effect_chain in zip(x, effects):
+            for i, (elem, effect_chain) in enumerate(zip(x, effects)):
                 elem = elem.unsqueeze(0)  # Add batch dim
+                log_wandb_audio_batch(
+                    logger=self.logger,
+                    id=f"{i}_Before",
+                    samples=elem.cpu(),
+                    sampling_rate=self.sample_rate,
+                    caption=effect_chain,
+                )
                 for effect in effect_chain:
                     # Get correct model based on effect name. This is a bit hacky
                     # Then sample the model
                     elem = self.model[effect.__name__].model.sample(elem)
+                    log_wandb_audio_batch(
+                        logger=self.logger,
+                        id=f"{i}_{effect}",
+                        samples=elem.cpu(),
+                        sampling_rate=self.sample_rate,
+                        caption=effect_chain,
+                    )
                 output.append(elem.squeeze(0))
         output = torch.stack(output)
 
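Note: the commit only shows the call sites; `log_wandb_audio_batch` itself lives in remfx/callbacks.py, which is not part of this diff. A hypothetical sketch of what such a helper plausibly does, assuming a Lightning WandbLogger whose `.experiment` is the active wandb run:

import wandb

def log_wandb_audio_batch(logger, id, samples, sampling_rate, caption=""):
    # Hypothetical sketch: log every clip in the batch as wandb.Audio under
    # a key derived from `id`, via the wandb run behind the Lightning logger.
    samples = samples.detach().cpu()
    for b in range(samples.shape[0]):
        clip = samples[b]
        if clip.dim() == 2:  # (channels, frames) -> (frames, channels)
            clip = clip.transpose(0, 1)
        logger.experiment.log(
            {
                f"audio/{id}_{b}": wandb.Audio(
                    clip.numpy(), sample_rate=sampling_rate, caption=str(caption)
                )
            }
        )

At the call sites above, each example is logged once before processing (`{i}_Before`) and once after every per-effect model in the chain (`{i}_{effect}`), so the dashboard shows the audio at each intermediate stage.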
scripts/chain_inference.py
CHANGED
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
 import hydra
 from omegaconf import DictConfig
 import remfx.utils as utils
-from pytorch_lightning.utilities.model_summary import ModelSummary
 import torch
 from remfx.models import RemFXChainInference
 
@@ -21,10 +20,9 @@ def main(cfg: DictConfig):
     for effect in cfg.ckpts:
         ckpt_path = cfg.ckpts[effect]
         model = hydra.utils.instantiate(cfg.model, _convert_="partial")
-        state_dict = torch.load(ckpt_path)[
-            "state_dict"
-        ]
+        state_dict = torch.load(ckpt_path)["state_dict"]
         model.load_state_dict(state_dict)
+        model.to(cfg.device)
         models[effect] = model
 
     callbacks = []
@@ -48,8 +46,6 @@ def main(cfg: DictConfig):
         callbacks=callbacks,
         logger=logger,
     )
-    summary = ModelSummary(model)
-    print(summary)
 
     inference_model = RemFXChainInference(
         models, sample_rate=cfg.sample_rate, num_bins=cfg.num_bins
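Note: besides collapsing the `torch.load(...)["state_dict"]` indexing onto one line, this hunk adds `model.to(cfg.device)` so each loaded model sits on the configured device before inference. A minimal sketch of the same pattern as a helper (`load_effect_model` and the `map_location` argument are illustrative additions, not part of the commit):

import torch

def load_effect_model(model, ckpt_path, device):
    # Lightning checkpoints are dicts; the weights live under "state_dict".
    # map_location lets a GPU-saved checkpoint load on CPU-only machines.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict)
    return model.to(device)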
scripts/test.py
CHANGED
@@ -44,8 +44,6 @@ def main(cfg: DictConfig):
         callbacks=callbacks,
         logger=logger,
     )
-    summary = ModelSummary(model)
-    print(summary)
     trainer.test(model=model, datamodule=datamodule)
 
 
scripts/train.py
CHANGED
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
 import hydra
 from omegaconf import DictConfig
 import remfx.utils as utils
-from pytorch_lightning.utilities.model_summary import ModelSummary
 
 log = utils.get_logger(__name__)
 
@@ -39,8 +38,6 @@ def main(cfg: DictConfig):
         callbacks=callbacks,
         logger=logger,
     )
-    summary = ModelSummary(model)
-    print(summary)
     trainer.fit(model=model, datamodule=datamodule)
     trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
 
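Note: the `ModelSummary` removals in scripts/train.py, scripts/test.py, and scripts/chain_inference.py pair with dropping `enable_model_summary: False` from cfg/config.yaml above. A minimal sketch of why the manual printing became redundant:

import pytorch_lightning as pl

# enable_model_summary defaults to True, so once the config stops pinning
# it to False the Trainer prints the layer/parameter summary itself at the
# start of fit/test; the removed `summary = ModelSummary(model)` /
# `print(summary)` lines duplicated that output.
trainer = pl.Trainer(enable_model_summary=True)  # explicit here; the default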