<a href="https://colab.research.google.com/drive/1i94k-n97Z5r1KWV9Vly9IiKnYxf3Tfvu?usp=sharing" target="_blank"><img align="left" alt="Colab" title="Open in Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a>

## First Neural Network: Image Classification 

Objectives:
- Train a minimal image classifier on [MNIST](https://paperswithcode.com/dataset/mnist)
- Use PyTorch for the model and training loop, and torchvision to load the data

In [None]:
# The usual imports

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [None]:
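# load the data

class ReshapeTransform:
    """Reshape a tensor to `new_size`; used here to flatten each image."""
    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        return torch.reshape(img, self.new_size)

transformations = transforms.Compose([
    # PIL image -> float32 tensor of shape (1, 28, 28) with values in [0, 1]
    transforms.ToTensor(),
    # no-op after ToTensor (already float32); kept to make the dtype explicit
    transforms.ConvertImageDtype(torch.float32),
    # flatten to a 784-element vector to match the first Linear layer
    ReshapeTransform((-1,)),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transformations)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transformations)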

In [None]:
# check shape of data

trainset.data.shape, testset.data.shape

(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))
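
The shapes above are for the raw tensors stored in `.data`; the transform pipeline is only applied when a sample is fetched by index. A minimal sketch of what the flattening step does to one image, using a random stand-in tensor rather than a real MNIST digit (the names `fake_image` and `flat` are illustrative only):

```python
import torch

# Stand-in for one MNIST image after ToTensor(): shape (channels, height, width),
# float32 values in [0, 1]. This is random data, not a real digit.
fake_image = torch.rand(1, 28, 28)

# Equivalent of ReshapeTransform((-1,)): flatten to a single vector
flat = torch.reshape(fake_image, (-1,))

print(fake_image.shape)  # torch.Size([1, 28, 28])
print(flat.shape)        # torch.Size([784])
```

The 784 here is 28 * 28, which is why the first `nn.Linear` layer of the model takes 784 inputs.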

In [None]:
# data loader

BATCH_SIZE = 128
train_dataloader = torch.utils.data.DataLoader(trainset, 
                                               batch_size=BATCH_SIZE,
                                               shuffle=True, 
                                               num_workers=0)

test_dataloader = torch.utils.data.DataLoader(testset, 
                                              batch_size=BATCH_SIZE,
                                              shuffle=False, 
                                              num_workers=0)
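
To see how a `DataLoader` splits a dataset into batches, here is a small sketch on synthetic data. The dataset size of 300 and the names `toy_dataset` and `toy_loader` are assumptions made for illustration; they are not part of the notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in dataset: 300 flattened "images" with random labels 0-9
features = torch.rand(300, 784)
labels = torch.randint(0, 10, (300,))
toy_dataset = TensorDataset(features, labels)

toy_loader = DataLoader(toy_dataset, batch_size=128, shuffle=False)

# The last batch is smaller because 300 is not a multiple of 128
batch_sizes = [X.shape[0] for X, y in toy_loader]
print(batch_sizes)  # [128, 128, 44]
```

The same applies to MNIST: 60000 training images with a batch size of 128 give 468 full batches plus a final batch of 96 images.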

In [None]:
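# model: a one-hidden-layer MLP
# 784 inputs (28*28 pixels) -> 512 hidden units -> 10 class scores (logits)

model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))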
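
As a quick sanity check on the size of this network, the sketch below rebuilds the same architecture under a separate name (`mlp`, chosen only for this example) and counts its trainable parameters.

```python
import torch
import torch.nn as nn

# Same architecture as the model cell above
mlp = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))

# Each Linear layer holds a weight matrix (out x in) and a bias vector (out)
n_params = sum(p.numel() for p in mlp.parameters())
print(n_params)  # 407050 = (784*512 + 512) + (512*10 + 10)

# A batch of 4 flattened images produces 4 rows of 10 class scores
scores = mlp(torch.rand(4, 784))
print(scores.shape)  # torch.Size([4, 10])
```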

In [None]:
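# training preparation

# RMSprop with its default hyperparameters (learning rate 0.01)
trainer = torch.optim.RMSprop(model.parameters())
# expects raw logits and integer class labels; applies log-softmax internally
loss = nn.CrossEntropyLoss()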
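
`nn.CrossEntropyLoss` takes raw logits, which is why the model has no softmax layer at the end. A small sketch on made-up logits and labels shows that it matches a log-softmax followed by the mean negative log-likelihood:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # raw, unnormalized scores for 4 samples
targets = torch.tensor([3, 0, 9, 1])   # integer class labels

ce = nn.CrossEntropyLoss()(logits, targets)

# Same value computed by hand: log-softmax, then mean negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(torch.allclose(ce, manual))  # True
```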

In [None]:
def get_accuracy(output, target, batch_size):
    # Percentage of correct predictions in one training batch.
    # Note: dividing by the nominal batch_size slightly underestimates
    # accuracy on a final batch that holds fewer samples.
    corrects = (output.argmax(dim=1) == target).sum()
    accuracy = 100.0 * corrects / batch_size
    return accuracy.item()
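
An alternative that does not depend on the batch size is to count correct predictions and samples separately and divide once at the end. This is a sketch of that idea, not the notebook's method; `count_correct` is a hypothetical helper and the two batches are made-up values.

```python
import torch

def count_correct(output, target):
    # Number of samples whose highest-scoring class matches the label
    return (output.argmax(dim=1) == target).sum().item()

# Two synthetic batches of unequal size: 4 samples, then 2 samples
batches = [
    (torch.tensor([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.0, 1.0]]),
     torch.tensor([0, 1, 1, 1])),   # 3 of 4 correct
    (torch.tensor([[5.0, 0.0], [0.0, 5.0]]),
     torch.tensor([0, 1])),         # 2 of 2 correct
]

correct = sum(count_correct(out, tgt) for out, tgt in batches)
total = sum(tgt.size(0) for _, tgt in batches)
epoch_accuracy = 100.0 * correct / total   # 5 correct out of 6 samples
print(epoch_accuracy)
```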

In [None]:
# train for 5 epochs

for epoch in range(5):
    train_acc = 0.0
    train_running_loss = 0.0

    model.train()
    for i, (X, y) in enumerate(train_dataloader):
        # forward pass: logits of shape (batch, 10) and the batch-mean loss
        output = model(X)
        l = loss(output, y)

        # update the parameters, then clear gradients for the next batch
        l.backward()
        trainer.step()
        trainer.zero_grad()

        # gather metrics
        train_acc += get_accuracy(output, y, BATCH_SIZE)
        train_running_loss += l.item()

    # report per-batch averages over the epoch
    print('Epoch: %d | Train loss: %.4f | Train Accuracy: %.4f'
          % (epoch + 1, train_running_loss / (i + 1), train_acc / (i + 1)))

Epoch: 1 | Train loss: 1.0415 | Train Accuracy: 91.9010
Epoch: 2 | Train loss: 0.1291 | Train Accuracy: 96.0871
Epoch: 3 | Train loss: 0.0997 | Train Accuracy: 97.0399
Epoch: 4 | Train loss: 0.0865 | Train Accuracy: 97.4913
Epoch: 5 | Train loss: 0.0740 | Train Accuracy: 97.8611
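
The update order used in the loop (backward, step, then zero the gradients) can be checked in isolation. The sketch below runs a single step on a synthetic batch with a small stand-in model; `tiny_model`, `tiny_opt`, and `criterion` are names invented for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tiny_model = nn.Linear(784, 10)    # stand-in for the full model
tiny_opt = torch.optim.RMSprop(tiny_model.parameters())
criterion = nn.CrossEntropyLoss()

X = torch.rand(8, 784)             # synthetic batch, not real MNIST
y = torch.randint(0, 10, (8,))

before = tiny_model.weight.detach().clone()

batch_loss = criterion(tiny_model(X), y)
batch_loss.backward()    # fills .grad on every parameter
tiny_opt.step()          # applies the update using those gradients
tiny_opt.zero_grad()     # clears gradients so they do not accumulate

weights_changed = not torch.equal(before, tiny_model.weight.detach())
# Depending on the PyTorch version, cleared gradients are None or all zeros
grads_cleared = all(p.grad is None or not p.grad.any()
                    for p in tiny_model.parameters())
print(weights_changed, grads_cleared)
```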


### Other things to try

- Evaluate on test set
- Plot loss curve
- Add more layers to the model
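
For the first item, one possible starting point is a small evaluation helper. This is a sketch under stated assumptions: `evaluate` is a hypothetical function, not part of the notebook, and the smoke test at the bottom runs on random data with an untrained model, so its numbers say nothing about MNIST.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, dataloader, loss_fn):
    """Return (mean loss per sample, accuracy in percent) over a dataloader."""
    model.eval()                       # switch off training-only behavior
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():              # no gradients needed for evaluation
        for X, y in dataloader:
            output = model(X)
            # loss_fn returns the batch mean, so weight it by the batch size
            total_loss += loss_fn(output, y).item() * y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return total_loss / total, 100.0 * correct / total

# Smoke test on synthetic data with an untrained model (not real MNIST)
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))
toy_loader = DataLoader(TensorDataset(torch.rand(200, 784),
                                      torch.randint(0, 10, (200,))),
                        batch_size=128)
toy_loss, toy_acc = evaluate(toy_model, toy_loader, nn.CrossEntropyLoss())
print(f'loss: {toy_loss:.4f} | accuracy: {toy_acc:.2f}')
```

In this notebook the call would look like `evaluate(model, test_dataloader, loss)`, run after the training loop has finished.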