Spaces:
Sleeping
Sleeping
File size: 1,055 Bytes
14ae0ea 8949a8c 14ae0ea 8949a8c 14ae0ea 8949a8c 14ae0ea 8949a8c 14ae0ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from datasets import GuitarFXDataset
from models import DiffusionGenerationModel, OpenUnmixModel
# Sample rate passed to GuitarFXDataset as `sample_rate`
# (presumably in Hz -- TODO confirm against the dataset implementation).
SAMPLE_RATE = 22050
# Fraction of the dataset assigned to the training subset; the remaining
# items form the validation subset.
TRAIN_SPLIT = 0.8
def main(
    data_root: str = "/Users/matthewrice/Developer/remfx/data/egfx",
    max_epochs: int = 10,
    batch_size: int = 2,
) -> None:
    """Train an ``OpenUnmixModel`` on ``GuitarFXDataset`` and log to W&B.

    The dataset is built with ``effect_type=["Phaser"]`` and split at random
    into training and validation subsets according to ``TRAIN_SPLIT``. The
    run is logged through a ``WandbLogger`` for the "RemFX" project.

    Args:
        data_root: Directory handed to ``GuitarFXDataset`` as ``root``.
            Defaults to the path that was previously hard-coded here.
        max_epochs: Upper bound on training epochs given to ``pl.Trainer``.
        batch_size: Batch size used by both the training and validation
            data loaders.
    """
    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=max_epochs)

    guitfx = GuitarFXDataset(
        root=data_root,
        sample_rate=SAMPLE_RATE,
        effect_type=["Phaser"],
    )

    # Derive the validation size from the remainder so the two sizes always
    # sum to len(guitfx), even when the product above is truncated by int().
    train_size = int(TRAIN_SPLIT * len(guitfx))
    val_size = len(guitfx) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        guitfx, [train_size, val_size]
    )

    # Reshuffle the training samples each epoch; validation order is left
    # fixed so its batches are comparable between epochs.
    train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val = DataLoader(val_dataset, batch_size=batch_size)

    # DiffusionGenerationModel is imported at file level as well; construct
    # it here instead to train that model.
    model = OpenUnmixModel()
    trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
# Start training only when this file is executed as a script, so importing
# the module does not kick off a run.
if __name__ == "__main__":
    main()
|