File size: 3,479 Bytes
14fef9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import torch
import gradio as gr
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Minimal demo interface: the user draws a digit on a canvas and gets the model's prediction.

# --- Training hyper-parameters ---------------------------------------------
# Only `learning_rate`, `momentum` and `random_seed` are consumed in the code
# visible here; the others are presumably read by a training routine — verify
# against their callers.
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
TRAIN_CUTOFF = 10  # NOTE(review): meaning not visible from this file — confirm

# --- Local paths -----------------------------------------------------------
MODEL_PATH = 'model'
METRIC_PATH = os.path.join(MODEL_PATH, 'metrics.json')
MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH, 'mnist_model.pth')
OPTIMIZER_PATH = os.path.join(MODEL_PATH, 'optimizer.pth')
REPOSITORY_DIR = "data"
LOCAL_DIR = 'data_local'

# --- Hugging Face Hub ------------------------------------------------------
# The token is read from the environment; it is None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = 'mnist-adversarial-model'
HF_DATASET = "mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
# Model repositories live directly under the user namespace on the Hub; only
# datasets and spaces carry a type prefix in their URL. The previous
# ".../model/chrisjay/..." form does not resolve to a repository.
MODEL_REPO_URL = f"https://huggingface.co/chrisjay/{MODEL_REPO}"


# Global torch setup: cuDNN is switched off and the RNG is seeded with
# `random_seed` (presumably so that runs are reproducible).
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)


# Pre-processing pipeline: convert the image to a tensor, then normalise the
# single channel with mean 0.1307 and std 0.3081.
TRAIN_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)



# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small convolutional classifier producing log-probabilities for 10 classes.

    Architecture: two 5x5 convolutions (1 -> 10 -> 20 channels), each followed
    by 2x2 max-pooling and ReLU, then two fully connected layers (320 -> 50 ->
    10). The flatten size of 320 (= 20 * 4 * 4) matches single-channel 28x28
    inputs; other spatial sizes will not reshape to one row per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Run the network on a batch of images.

        :param x: image batch; expected as (N, 1, 28, 28) so that the feature
            map flattens to 320 values per sample.
        :return: tensor of shape (N, 10) holding per-class log-probabilities.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten to one row of 320 features per sample.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout: only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # `x` is 2-D here, so the class axis is dim 1. Passing it explicitly
        # avoids the deprecated implicit-dimension behaviour (and its warning).
        return F.log_softmax(x, dim=1)




# Re-apply the seed immediately before the model is built, so the randomly
# initialised weights depend only on `random_seed`.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

# Freshly initialised (random-weight) network and the SGD optimizer over its
# parameters.
network = MNIST_Model()
optimizer = optim.SGD(
    network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)


# Train
#train(n_epochs,network,optimizer)


def image_classifier(inp):
    """
    Classify a drawn digit and return per-class confidence scores.

    :param inp: the image to be classified; it must be something
        ``torchvision.transforms.ToTensor`` accepts (the UI in this file
        supplies a 28x28 single-channel PIL image).
    :return: A dictionary mapping each class index (int) to its confidence
        (float), inserted in order of decreasing confidence.
    """
    # Put the global network in inference mode. Modules start out in training
    # mode, where the dropout layers randomly zero activations, so the same
    # drawing would otherwise get different scores on every call.
    network.eval()

    # Add a leading batch dimension of size 1.
    # NOTE(review): only ToTensor is applied here, whereas TRAIN_TRANSFORM also
    # normalises the pixels — confirm which pre-processing the model expects.
    input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)

    with torch.no_grad():
        # The network emits log-probabilities for the single batch element;
        # softmax over the class axis turns them into values that sum to 1.
        prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)

    ranked = torch.sort(prediction, descending=True)
    return dict(zip(ranked.indices.tolist(), ranked.values.tolist()))




def main():
    """Build the drawing-canvas demo and launch it.

    The UI is a single row holding a 28x28 grayscale canvas and a label
    widget showing the top two classes; every change to the canvas re-runs
    `image_classifier` on the current drawing.
    """
    with gr.Blocks() as block:
        with gr.Row():
            image_input = gr.inputs.Image(
                source="canvas",
                shape=(28, 28),
                invert_colors=True,
                image_mode="L",
                type="pil",
            )
            label_output = gr.outputs.Label(num_top_classes=2)

        # Re-classify whenever the drawing changes.
        image_input.change(
            image_classifier,
            inputs=[image_input],
            outputs=[label_output],
        )

    block.launch()


if __name__ == "__main__":
    main()