mattricesound commited on
Commit
898adbc
·
2 Parent(s): 9216427 dfbeb31

Merge pull request #20 from mhrice/fix-demucs

Browse files
Files changed (3) hide show
  1. README.md +4 -1
  2. exp/demucs.yaml +1 -1
  3. remfx/models.py +28 -1
README.md CHANGED
@@ -10,10 +10,13 @@
10
  `./scripts/download_egfx.sh`
11
 
12
  ## Train model
13
- 1. Change Wandb variables in `shell_vars.sh`
14
  2. `python scripts/train.py exp=audio_diffusion`
15
  or
16
  2. `python scripts/train.py exp=umx`
 
 
 
17
 
18
  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
19
 
 
10
  `./scripts/download_egfx.sh`
11
 
12
  ## Train model
13
+ 1. Change Wandb variables in `shell_vars.sh` and `source shell_vars.sh`
14
  2. `python scripts/train.py exp=audio_diffusion`
15
  or
16
  2. `python scripts/train.py exp=umx`
17
+ or
18
+ 2. `python scripts/train.py exp=demucs`
19
+
20
 
21
  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
22
 
exp/demucs.yaml CHANGED
@@ -9,7 +9,7 @@ model:
9
  sample_rate: ${sample_rate}
10
  network:
11
  _target_: remfx.models.DemucsModel
12
- sources: ["other"]
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
 
9
  sample_rate: ${sample_rate}
10
  network:
11
  _target_: remfx.models.DemucsModel
12
+ sources: ["mixture"]
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
remfx/models.py CHANGED
@@ -38,6 +38,8 @@ class RemFXModel(pl.LightningModule):
38
  "L1": L1Loss(),
39
  }
40
  )
 
 
41
 
42
  @property
43
  def device(self):
@@ -67,9 +69,14 @@ class RemFXModel(pl.LightningModule):
67
  x, y, label = batch
68
  # Metric logging
69
  for metric in self.metrics:
 
 
 
 
 
70
  self.log(
71
  f"{mode}_{metric}",
72
- self.metrics[metric](output, y),
73
  on_step=False,
74
  on_epoch=True,
75
  logger=True,
@@ -79,6 +86,26 @@ class RemFXModel(pl.LightningModule):
79
 
80
  return loss
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def on_validation_epoch_start(self):
83
  self.log_next = True
84
 
 
38
  "L1": L1Loss(),
39
  }
40
  )
41
+ # Log first batch metrics input vs output only once
42
+ self.log_first = True
43
 
44
  @property
45
  def device(self):
 
69
  x, y, label = batch
70
  # Metric logging
71
  for metric in self.metrics:
72
+ # SISDR returns negative values, so negate them
73
+ if metric == "SISDR":
74
+ negate = -1
75
+ else:
76
+ negate = 1
77
  self.log(
78
  f"{mode}_{metric}",
79
+ negate * self.metrics[metric](output, y),
80
  on_step=False,
81
  on_epoch=True,
82
  logger=True,
 
86
 
87
  return loss
88
 
89
+ def on_train_batch_start(self, batch, batch_idx):
90
+ if self.log_first:
91
+ x, target, label = batch
92
+ for metric in self.metrics:
93
+ # SISDR returns negative values, so negate them
94
+ if metric == "SISDR":
95
+ negate = -1
96
+ else:
97
+ negate = 1
98
+ self.log(
99
+ f"Input_{metric}",
100
+ negate * self.metrics[metric](x, target),
101
+ on_step=False,
102
+ on_epoch=True,
103
+ logger=True,
104
+ prog_bar=True,
105
+ sync_dist=True,
106
+ )
107
+ self.log_first = False
108
+
109
  def on_validation_epoch_start(self):
110
  self.log_next = True
111