Spaces:
Running
on
Zero
Running
on
Zero
| import pytorch_lightning as pl | |
| from torch.utils.data import DataLoader | |
| from dataset import MyDataset, load_filenames # dataset.pyに基づく | |
class DataModule(pl.LightningDataModule):
    """Lightning data module that serves a training set built from ``img_dir``.

    The file list comes from ``load_filenames(img_dir)``; only the first
    ``file_num`` entries are wrapped in a ``MyDataset`` for training.
    """

    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0,
                 file_num=1000):
        """Store the loader configuration.

        Args:
            img_dir: Image directory, forwarded unchanged to
                ``load_filenames`` and ``MyDataset``.
            batch_size: Batch size handed to the training ``DataLoader``.
            img_size: Forwarded to ``MyDataset`` as ``img_size``.
            num_workers: Number of ``DataLoader`` worker processes;
                0 loads data in the main process.
            file_num: Maximum number of filenames (taken from the start of
                the list) used for training. Defaults to 1000, the value
                that was previously hard-coded (3400 was the noted
                alternative).
        """
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        self.file_num = file_num

    def setup(self, stage=None):
        """Build ``self.train_dataset`` from the first ``file_num`` filenames.

        Args:
            stage: Accepted for Lightning's hook signature; not used here.
        """
        # NOTE(review): assumes load_filenames returns a sliceable sequence
        # in a stable order -- confirm in dataset.py.
        filenames = load_filenames(self.img_dir)
        self.train_dataset = MyDataset(
            filenames[:self.file_num],
            img_dir=self.img_dir,
            img_size=self.img_size,
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``self.train_dataset``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError when persistent_workers=True is
            # combined with num_workers=0 (the default here), so only keep
            # workers alive when there are workers to keep.
            persistent_workers=self.num_workers > 0,
        )