Spaces:
Sleeping
Sleeping
Commit
·
a040dee
1
Parent(s):
9eba2f5
Add custom effect ordering during chain inference
Browse files- remfx/models.py +3 -4
- scripts/chain_inference.py +7 -1
remfx/models.py
CHANGED
@@ -52,7 +52,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
52 |
with torch.no_grad():
|
53 |
for i, (elem, effects_list) in enumerate(zip(x, effects)):
|
54 |
elem = elem.unsqueeze(0) # Add batch dim
|
55 |
-
|
56 |
effects_order.index(effect.__name__) for effect in effects_list
|
57 |
]
|
58 |
# log_wandb_audio_batch(
|
@@ -62,10 +62,9 @@ class RemFXChainInference(pl.LightningModule):
|
|
62 |
# sampling_rate=self.sample_rate,
|
63 |
# caption=effect_chain,
|
64 |
# )
|
65 |
-
|
66 |
-
for effect in effect_chain:
|
67 |
# Sample the model
|
68 |
-
elem = self.model[
|
69 |
# log_wandb_audio_batch(
|
70 |
# logger=self.logger,
|
71 |
# id=f"{i}_{effect}",
|
|
|
52 |
with torch.no_grad():
|
53 |
for i, (elem, effects_list) in enumerate(zip(x, effects)):
|
54 |
elem = elem.unsqueeze(0) # Add batch dim
|
55 |
+
effect_chain_idx = [
|
56 |
effects_order.index(effect.__name__) for effect in effects_list
|
57 |
]
|
58 |
# log_wandb_audio_batch(
|
|
|
62 |
# sampling_rate=self.sample_rate,
|
63 |
# caption=effect_chain,
|
64 |
# )
|
65 |
+
for idx in effect_chain_idx:
|
|
|
66 |
# Sample the model
|
67 |
+
elem = self.model[effects_order[idx]].model.sample(elem)
|
68 |
# log_wandb_audio_batch(
|
69 |
# logger=self.logger,
|
70 |
# id=f"{i}_{effect}",
|
scripts/chain_inference.py
CHANGED
@@ -51,7 +51,13 @@ def main(cfg: DictConfig):
|
|
51 |
models,
|
52 |
sample_rate=cfg.sample_rate,
|
53 |
num_bins=cfg.num_bins,
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
trainer.test(model=inference_model, datamodule=datamodule)
|
57 |
|
|
|
51 |
models,
|
52 |
sample_rate=cfg.sample_rate,
|
53 |
num_bins=cfg.num_bins,
|
54 |
+
effect_order=[
|
55 |
+
"RandomPedalboardDistortion",
|
56 |
+
"RandomPedalboardCompressor",
|
57 |
+
"RandomPedalboardReverb",
|
58 |
+
"RandomPedalboardChorus",
|
59 |
+
"RandomPedalboardDelay",
|
60 |
+
],
|
61 |
)
|
62 |
trainer.test(model=inference_model, datamodule=datamodule)
|
63 |
|