File size: 2,431 Bytes
af3a445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Callable,List,Any
from pathlib import Path

from lightning import LightningDataModule
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS,EVAL_DATALOADERS
import torch
from torch.utils.data import DataLoader,random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# Module-level device handle: CUDA when torch reports it is available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LitMNISTDataModule(LightningDataModule):
    """Lightning data module that downloads MNIST and serves train/val/test loaders.

    The official MNIST training set is split 90% / 10% into training and
    validation subsets; the official test set is used unchanged for testing.

    Args:
        data_dir: Directory passed to ``torchvision.datasets.MNIST`` as its root.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
        test_transform: Transform applied to samples of the test set.
        train_transform: Transform applied to samples of the training set
            (and therefore also to the validation subset carved out of it).
            Both defaults convert to a tensor and normalize with
            mean 0.1307 / std 0.3081.
    """

    def __init__(
            self,
            data_dir:Path = Path('.'),
            batch_size:int = 32,
            num_workers:int = 0,
            test_transform:Callable =  transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(.1307,),std=(.3081,))]),
            train_transform:Callable = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(.1307,),std=(.3081,))])
    ) -> None:
        super().__init__()
        self.data_dir:Path = data_dir
        self.batch_size:int = batch_size
        self.num_workers:int = num_workers
        self.test_transform:Callable = test_transform
        self.train_transform:Callable = train_transform
        # Records the constructor arguments via Lightning's hyperparameter mechanism.
        self.save_hyperparameters()

    @staticmethod
    def _new_generator() -> torch.Generator:
        """Create a fresh RNG on the device where torch allocates new tensors.

        ``random_split`` and the DataLoader sampler draw their random tensors
        on torch's current default device, and the generator handed to them
        must live on that same device. Probing an empty tensor works whether
        or not the default device has been switched (e.g. to CUDA), whereas
        choosing by ``torch.cuda.is_available()`` fails on a CUDA machine whose
        default device is still the CPU.
        """
        return torch.Generator(torch.empty(0).device)

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``dataset`` using the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=None,
            shuffle=shuffle,
            generator=self._new_generator(),
        )

    def prepare_data(self) -> None:
        """Download both MNIST splits into ``data_dir`` (no state is stored here)."""
        MNIST(self.data_dir,train=True,download=True)
        MNIST(self.data_dir,train=False,download=True)

    def setup(self, stage: str=None) -> None:
        """Instantiate the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/val split,
                ``"test"`` builds the test set, ``None`` builds everything.
                Any other value builds nothing.
        """
        # "validate" is included so that val_dataloader() has its dataset when
        # validation is run on its own, without a preceding fit.
        if stage in ("fit", "validate", None):
            _mnist_full = MNIST(self.data_dir,train=True,transform=self.train_transform)
            self.mnist_train, self.mnist_val = random_split(
                _mnist_full,
                [.9,.1],
                generator=self._new_generator(),
            )

        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir,train=False, transform=self.test_transform)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return a shuffling dataloader over the training subset."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Return a non-shuffling dataloader over the validation subset."""
        return self._make_loader(self.mnist_val, shuffle=False)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return a non-shuffling dataloader over the MNIST test set."""
        return self._make_loader(self.mnist_test, shuffle=False)