Spaces:
Sleeping
Sleeping
File size: 2,431 Bytes
af3a445 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from pathlib import Path
from typing import Any, Callable, List, Optional

import torch
from lightning import LightningDataModule
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
# Preferred compute device: the CUDA GPU when one is visible, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class LitMNISTDataModule(LightningDataModule):
    """Lightning data module serving MNIST train / validation / test loaders.

    The MNIST training set is split 90% / 10% into a training and a
    validation subset; the MNIST test set is used unchanged.

    Args:
        data_dir: Directory the dataset is downloaded to and read from.
        batch_size: Number of samples per batch for every loader.
        num_workers: Worker process count handed to every ``DataLoader``.
        test_transform: Transform applied to test samples. The default
            converts to a tensor and normalises with mean 0.1307 and
            std 0.3081.
        train_transform: Transform applied to the training set. The
            validation subset is carved out of that same dataset object, so
            it receives this transform as well. Same default as
            ``test_transform``.
    """

    def __init__(
        self,
        data_dir: Path = Path('.'),
        batch_size: int = 32,
        num_workers: int = 0,
        test_transform: Callable = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(.1307,), std=(.3081,))]),
        train_transform: Callable = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(.1307,), std=(.3081,))]),
    ) -> None:
        super().__init__()
        self.data_dir: Path = data_dir
        self.batch_size: int = batch_size
        self.num_workers: int = num_workers
        self.test_transform: Callable = test_transform
        self.train_transform: Callable = train_transform
        # Must be called from __init__ itself so the constructor arguments
        # are the ones that get recorded.
        self.save_hyperparameters()

    @staticmethod
    def _new_generator() -> torch.Generator:
        """Return a fresh generator on the device new tensors are created on.

        ``random_split`` and ``DataLoader`` draw their random indices and
        seeds from tensors allocated on torch's current default device. A
        generator that lives on a different device (for example a CUDA
        generator while tensors default to CPU) makes those calls raise
        ``RuntimeError``, so the generator device is derived from an empty
        probe tensor instead of from CUDA availability.
        """
        return torch.Generator(device=torch.empty(0).device)

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a ``DataLoader`` over ``dataset`` using the shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=None,
            shuffle=shuffle,
            generator=self._new_generator(),
        )

    def prepare_data(self) -> None:
        """Download the MNIST train and test sets into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/validation
                split, ``'test'`` builds the test set, and ``None`` builds
                everything. Any other value builds nothing.
        """
        # "validate" needs the split too: a validation-only run would
        # otherwise reach val_dataloader() with self.mnist_val undefined.
        if stage in ("fit", "validate", None):
            full_train = MNIST(self.data_dir, train=True, transform=self.train_transform)
            self.mnist_train, self.mnist_val = random_split(
                full_train, [.9, .1], generator=self._new_generator()
            )
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.test_transform)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the shuffled loader over the training subset."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Return the unshuffled loader over the validation subset."""
        return self._make_loader(self.mnist_val, shuffle=False)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return the unshuffled loader over the test set."""
        return self._make_loader(self.mnist_test, shuffle=False)
|