# NOTE(review): the helper below is disabled (commented out) and kept only for
# reference. Presumably it was used to estimate the normalization statistics
# defined further down -- confirm before deleting. As written it would not run:
# `data_loader` is only annotated, never assigned from `datamodule`.
#
# from lightning.pytorch.utilities.types import TRAIN_DATALOADERS,EVAL_DATALOADERS
# def calculate_mean_std_mnist(datamodule:pl.LightningDataModule):
#     data_loader:TRAIN_DATALOADERS;
#     mean = torch.zeros(1);
#     std = torch.zeros(1)
#     num_samples = 0
#     for img in data_loader:
#         image = img[0]
#         image = image.squeeze()
#         mean += image.mean() # mean across channel sum for all pics
#         std += image.std()
#         num_samples += 1
#     mean /= num_samples
#     std /= num_samples
#     return (mean.item(),std.item())

from torchvision import transforms

# Per-channel normalization statistics shared by the train and test pipelines.
# One-element tuples, i.e. a single channel. Presumably the MNIST training-set
# mean/std (cf. the disabled helper's name) -- confirm against the dataset used.
# Defined once so the two pipelines cannot drift apart.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training-time preprocessing with light augmentation.
TRAIN_TRANSFORMS = transforms.Compose([
    # With probability 0.1, keep only the central 22x22 region.
    transforms.RandomApply([transforms.CenterCrop(22)], p=0.1),
    # Random rotation/shear/translation/zoom: rotation from degrees=7,
    # shear=10, translation up to a 0.1 fraction per axis, scale in [0.8, 1.2].
    transforms.RandomAffine(
        degrees=7,
        shear=10,
        translate=(0.1, 0.1),
        scale=(0.8, 1.2),
    ),
    # Bring every sample to 28x28 regardless of whether the crop was applied.
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Evaluation-time preprocessing: no augmentation, only tensor conversion and
# the same normalization as the training pipeline.
TEST_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])