mattricesound commited on
Commit
80a1624
·
1 Parent(s): 6af03de

Added intermediate logging of chain_inference

Browse files
cfg/config.yaml CHANGED
@@ -117,7 +117,6 @@ trainer:
117
  precision: 32 # Precision used for tensors, default `32`
118
  min_epochs: 0
119
  max_epochs: -1
120
- enable_model_summary: False
121
  log_every_n_steps: 1 # Logs metrics every N batches
122
  accumulate_grad_batches: 1
123
  accelerator: ${accelerator}
 
117
  precision: 32 # Precision used for tensors, default `32`
118
  min_epochs: 0
119
  max_epochs: -1
 
120
  log_every_n_steps: 1 # Logs metrics every N batches
121
  accumulate_grad_batches: 1
122
  accelerator: ${accelerator}
cfg/model/demucs.yaml CHANGED
@@ -13,3 +13,5 @@ model:
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
 
 
 
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
+ channels: 64
17
+
remfx/models.py CHANGED
@@ -11,6 +11,7 @@ from umx.openunmix.model import OpenUnmix, Separator
11
  from remfx.utils import FADLoss, spectrogram
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
 
14
  from remfx import effects
15
  import asteroid
16
 
@@ -42,12 +43,26 @@ class RemFXChainInference(pl.LightningModule):
42
  ]
43
  output = []
44
  with torch.no_grad():
45
- for elem, effect_chain in zip(x, effects):
46
  elem = elem.unsqueeze(0) # Add batch dim
 
 
 
 
 
 
 
47
  for effect in effect_chain:
48
  # Get correct model based on effect name. This is a bit hacky
49
  # Then sample the model
50
  elem = self.model[effect.__name__].model.sample(elem)
 
 
 
 
 
 
 
51
  output.append(elem.squeeze(0))
52
  output = torch.stack(output)
53
 
 
11
  from remfx.utils import FADLoss, spectrogram
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
14
+ from remfx.callbacks import log_wandb_audio_batch
15
  from remfx import effects
16
  import asteroid
17
 
 
43
  ]
44
  output = []
45
  with torch.no_grad():
46
+ for i, (elem, effect_chain) in enumerate(zip(x, effects)):
47
  elem = elem.unsqueeze(0) # Add batch dim
48
+ log_wandb_audio_batch(
49
+ logger=self.logger,
50
+ id=f"{i}_Before",
51
+ samples=elem.cpu(),
52
+ sampling_rate=self.sample_rate,
53
+ caption=effect_chain,
54
+ )
55
  for effect in effect_chain:
56
  # Get correct model based on effect name. This is a bit hacky
57
  # Then sample the model
58
  elem = self.model[effect.__name__].model.sample(elem)
59
+ log_wandb_audio_batch(
60
+ logger=self.logger,
61
+ id=f"{i}_{effect}",
62
+ samples=elem.cpu(),
63
+ sampling_rate=self.sample_rate,
64
+ caption=effect_chain,
65
+ )
66
  output.append(elem.squeeze(0))
67
  output = torch.stack(output)
68
 
scripts/chain_inference.py CHANGED
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
2
  import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
- from pytorch_lightning.utilities.model_summary import ModelSummary
6
  import torch
7
  from remfx.models import RemFXChainInference
8
 
@@ -21,10 +20,9 @@ def main(cfg: DictConfig):
21
  for effect in cfg.ckpts:
22
  ckpt_path = cfg.ckpts[effect]
23
  model = hydra.utils.instantiate(cfg.model, _convert_="partial")
24
- state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[
25
- "state_dict"
26
- ]
27
  model.load_state_dict(state_dict)
 
28
  models[effect] = model
29
 
30
  callbacks = []
@@ -48,8 +46,6 @@ def main(cfg: DictConfig):
48
  callbacks=callbacks,
49
  logger=logger,
50
  )
51
- summary = ModelSummary(model)
52
- print(summary)
53
 
54
  inference_model = RemFXChainInference(
55
  models, sample_rate=cfg.sample_rate, num_bins=cfg.num_bins
 
2
  import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
 
5
  import torch
6
  from remfx.models import RemFXChainInference
7
 
 
20
  for effect in cfg.ckpts:
21
  ckpt_path = cfg.ckpts[effect]
22
  model = hydra.utils.instantiate(cfg.model, _convert_="partial")
23
+ state_dict = torch.load(ckpt_path)["state_dict"]
 
 
24
  model.load_state_dict(state_dict)
25
+ model.to(cfg.device)
26
  models[effect] = model
27
 
28
  callbacks = []
 
46
  callbacks=callbacks,
47
  logger=logger,
48
  )
 
 
49
 
50
  inference_model = RemFXChainInference(
51
  models, sample_rate=cfg.sample_rate, num_bins=cfg.num_bins
scripts/test.py CHANGED
@@ -44,8 +44,6 @@ def main(cfg: DictConfig):
44
  callbacks=callbacks,
45
  logger=logger,
46
  )
47
- summary = ModelSummary(model)
48
- print(summary)
49
  trainer.test(model=model, datamodule=datamodule)
50
 
51
 
 
44
  callbacks=callbacks,
45
  logger=logger,
46
  )
 
 
47
  trainer.test(model=model, datamodule=datamodule)
48
 
49
 
scripts/train.py CHANGED
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
2
  import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
- from pytorch_lightning.utilities.model_summary import ModelSummary
6
 
7
  log = utils.get_logger(__name__)
8
 
@@ -39,8 +38,6 @@ def main(cfg: DictConfig):
39
  callbacks=callbacks,
40
  logger=logger,
41
  )
42
- summary = ModelSummary(model)
43
- print(summary)
44
  trainer.fit(model=model, datamodule=datamodule)
45
  trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
46
 
 
2
  import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
 
5
 
6
  log = utils.get_logger(__name__)
7
 
 
38
  callbacks=callbacks,
39
  logger=logger,
40
  )
 
 
41
  trainer.fit(model=model, datamodule=datamodule)
42
  trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
43