File size: 1,757 Bytes
aa56c48
a85ef54
aa56c48
 
ea81160
 
 
0f7c18e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea81160
 
 
a85ef54
aa56c48
0f7c18e
 
 
 
a85ef54
aa56c48
 
 
 
5940022
0f7c18e
 
 
 
aa56c48
0f7c18e
 
aa56c48
 
5940022
0f7c18e
aa56c48
ea81160
 
aa56c48
0f7c18e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import gradio as gr
from torchvision import transforms
from PIL import ImageOps
import os
from dotenv import load_dotenv

from torch import nn
import torch.nn.functional as F

class SimpleLenet(nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2) # -> 6 channels, 28x28
        self.pool = nn.MaxPool2d(2) # -> 6 channels, 14x14
        self.conv2 = nn.Conv2d(6, 120, 14) #-> 120 channels, 1x1
        self.fc1 = nn.Linear(120, 10)
        self.fc2 = nn.Linear(10, 10)

    def __call__(self, x):
        xx = F.relu(self.conv1(x))
        xx = F.relu(self.pool(xx))
        xx = F.relu(self.conv2(xx))
        xx = xx.flatten(1)
        xx = F.relu(self.fc1(xx))
        return self.fc2(xx)

load_dotenv()

hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging")

def load_model():
    model = SimpleLenet()
    model.load_state_dict(torch.load('model.pt'))
    model.eval()
    return model

model = load_model()
convert_tensor = transforms.ToTensor()

def predict(img):
    img =  ImageOps.grayscale(img).resize((28,28))
    image_tensor = convert_tensor(img).view(1, 1, 28, 28)
    logits = model(image_tensor)
    pred = torch.argmax(logits, dim=1)
    return pred.tolist()[0]

title = "Handwritten digit recognition"
description = '<p><center>Write a single digit in the middle of the canvas</center></p>'

gr.Interface(fn=predict, 
             inputs=gr.Paint(type="pil", invert_colors=True),
             outputs="text",
             title=title,
             flagging_options=["incorrect","ambiguous"],
             flagging_callback=hf_writer,
             description=description,
             allow_flagging='manual').launch()