Spaces:
Sleeping
Sleeping
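# NOTE(review): unused helper kept for reference only; as written it would not
# run and would not compute the dataset statistics correctly:
#   - `data_loader` is only annotated, never obtained from `datamodule`, so the
#     loop has nothing to iterate.
#   - averaging the per-iteration `std()` values is not the same as the std over
#     all pixels of the dataset.
#   - `num_samples` counts loop iterations (batches, if the loader yields
#     batches), not individual images.
# from lightning.pytorch.utilities.types import TRAIN_DATALOADERS,EVAL_DATALOADERS
# def calculate_mean_std_mnist(datamodule:pl.LightningDataModule):
# data_loader:TRAIN_DATALOADERS;
# mean = torch.zeros(1);
# std = torch.zeros(1)
# num_samples = 0
# for img in data_loader:
# image = img[0]
# image = image.squeeze()
# mean += image.mean() # mean across channel sum for all pics
# std += image.std()
# num_samples += 1
# mean /= num_samples
# std /= num_samples
# return (mean.item(),std.item())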
from torchvision import transforms | |
# Single-channel normalization statistics shared by both pipelines so that
# training and evaluation inputs are always normalized identically.
# 0.1307 / 0.3081 are the commonly used MNIST pixel mean / std, expressed on
# the [0, 1] scale that ToTensor produces.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training pipeline: light random geometric augmentation, then conversion to a
# normalized tensor.
TRAIN_TRANSFORMS = transforms.Compose([
    # In 10% of samples, zoom in by keeping only the central 22x22 region.
    transforms.RandomApply([transforms.CenterCrop(22)], p=0.1),
    # Small random rotation (+/-7 deg), x-shear (+/-10 deg), shift (up to 10%
    # of width/height) and rescale (0.8x-1.2x).
    transforms.RandomAffine(degrees=7, shear=10, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    # Restore a fixed 28x28 size (needed after the optional 22x22 crop).
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same
# normalization used for training.
# NOTE(review): unlike TRAIN_TRANSFORMS there is no Resize here, so inputs are
# assumed to already be 28x28 -- confirm for any non-MNIST inputs.
TEST_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])