mattricesound commited on
Commit
a040dee
·
1 Parent(s): 9eba2f5

Add custom effect ordering during chain inference

Browse files
Files changed (2) hide show
  1. remfx/models.py +3 -4
  2. scripts/chain_inference.py +7 -1
remfx/models.py CHANGED
@@ -52,7 +52,7 @@ class RemFXChainInference(pl.LightningModule):
52
  with torch.no_grad():
53
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
54
  elem = elem.unsqueeze(0) # Add batch dim
55
- effect_chain = [
56
  effects_order.index(effect.__name__) for effect in effects_list
57
  ]
58
  # log_wandb_audio_batch(
@@ -62,10 +62,9 @@ class RemFXChainInference(pl.LightningModule):
62
  # sampling_rate=self.sample_rate,
63
  # caption=effect_chain,
64
  # )
65
- effect_chain
66
- for effect in effect_chain:
67
  # Sample the model
68
- elem = self.model[effect].model.sample(elem)
69
  # log_wandb_audio_batch(
70
  # logger=self.logger,
71
  # id=f"{i}_{effect}",
 
52
  with torch.no_grad():
53
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
54
  elem = elem.unsqueeze(0) # Add batch dim
55
+ effect_chain_idx = [
56
  effects_order.index(effect.__name__) for effect in effects_list
57
  ]
58
  # log_wandb_audio_batch(
 
62
  # sampling_rate=self.sample_rate,
63
  # caption=effect_chain,
64
  # )
65
+ for idx in effect_chain_idx:
 
66
  # Sample the model
67
+ elem = self.model[effects_order[idx]].model.sample(elem)
68
  # log_wandb_audio_batch(
69
  # logger=self.logger,
70
  # id=f"{i}_{effect}",
scripts/chain_inference.py CHANGED
@@ -51,7 +51,13 @@ def main(cfg: DictConfig):
51
  models,
52
  sample_rate=cfg.sample_rate,
53
  num_bins=cfg.num_bins,
54
- order=["Distortion", "Compressor", "Reverb", "Chorus", "Delay"],
 
 
 
 
 
 
55
  )
56
  trainer.test(model=inference_model, datamodule=datamodule)
57
 
 
51
  models,
52
  sample_rate=cfg.sample_rate,
53
  num_bins=cfg.num_bins,
54
+ effect_order=[
55
+ "RandomPedalboardDistortion",
56
+ "RandomPedalboardCompressor",
57
+ "RandomPedalboardReverb",
58
+ "RandomPedalboardChorus",
59
+ "RandomPedalboardDelay",
60
+ ],
61
  )
62
  trainer.test(model=inference_model, datamodule=datamodule)
63