UnsolvedMNIST / utils /utils.py
Muthukamalan's picture
src file added
af3a445
# from lightning.pytorch.utilities.types import TRAIN_DATALOADERS,EVAL_DATALOADERS
# def calculate_mean_std_mnist(datamodule:pl.LightningDataModule):
# data_loader:TRAIN_DATALOADERS;
# mean = torch.zeros(1);
# std = torch.zeros(1)
# num_samples = 0
# for img in data_loader:
# image = img[0]
# image = image.squeeze()
# mean += image.mean() # mean across channel sum for all pics
# std += image.std()
# num_samples += 1
# mean /= num_samples
# std /= num_samples
# return (mean.item(),std.item())
from torchvision import transforms
# Training-time preprocessing: mild random augmentation followed by conversion
# to a normalized tensor. Steps run in the order listed.
TRAIN_TRANSFORMS = transforms.Compose(
    transforms=[
        # With probability 0.1, keep only the central 22x22 region.
        transforms.RandomApply(
            transforms=[transforms.CenterCrop(size=22)],
            p=0.1,
        ),
        # Random rotation, shift, zoom and shear drawn from the given ranges.
        transforms.RandomAffine(
            degrees=7,
            translate=(0.1, 0.1),
            scale=(0.8, 1.2),
            shear=10,
        ),
        # Bring every image to 28x28, so the occasionally cropped ones end up
        # the same size as the rest.
        transforms.Resize(size=(28, 28)),
        transforms.ToTensor(),
        # Presumably the MNIST dataset mean / std -- confirm against the data.
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
# Evaluation-time preprocessing: no augmentation, only tensor conversion and
# normalization with the same constants TRAIN_TRANSFORMS uses.
TEST_TRANSFORMS = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
        # Presumably the MNIST dataset mean / std -- confirm against the data.
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)