stanimirovb's picture
access model like that, maybe?
2f12319 verified
raw
history blame
1.53 kB
import gradio as gr
import numpy as np
import torch
import torchvision
from torch import nn
class LeNet(nn.Module):
    """Small LeNet-style CNN producing 10 class scores for a 1x28x28 image.

    Two 5x5 convolution stages, each followed by tanh and 2x2 average
    pooling, feed a single 10-way linear layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and positions inside each Sequential are kept
        # unchanged so state-dict keys (convs.0, convs.3, linear.0) still
        # match previously saved checkpoints.
        feature_layers = [
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(4, 12, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        ]
        self.convs = nn.Sequential(*feature_layers)
        # For a 28x28 input: 28 -> 24 -> 12 -> 8 -> 4, leaving 12 maps of 4x4.
        self.linear = nn.Sequential(nn.Linear(4 * 4 * 12, 10))

    def forward(self, x):
        """Return the raw class scores (logits) for the batch ``x``."""
        features = self.convs(x)
        # Keep the batch dimension, collapse everything else.
        return self.linear(features.flatten(1))

    @torch.no_grad()
    def predict(self, input):
        """Return the softmax class probabilities for a single image.

        ``input`` is reshaped to ``(1, 1, 28, 28)``, so it must hold exactly
        28 * 28 elements. The result is a 1-D tensor of 10 probabilities
        computed without gradient tracking.
        """
        batch = input.reshape(1, 1, 28, 28)
        logits = self(batch)[0]
        return torch.softmax(logits, dim=0)
# Build the network and restore its trained parameters on the CPU.
lenet = LeNet()
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while being loaded.
lenet.load_state_dict(
    torch.load('ibob-lenet-v1/lenet-v1.pth', map_location='cpu', weights_only=True)
)
# The model is only used for inference; eval mode makes that explicit.
lenet.eval()
# Scales an incoming image to 28x28, the size LeNet.predict reshapes to.
resize = torchvision.transforms.Resize((28, 28), antialias=True)
def on_submit(img):
    """Classify the submitted sketch and return the class ranking as text.

    Args:
        img: Value passed in by the sketchpad input. Its ``'composite'``
            entry is read as a NumPy array.
            NOTE(review): assumed to be a 2-D (H, W) array, since the resized
            result must hold exactly 28 * 28 values -- confirm against the
            configured image mode.

    Returns:
        One ``[class_index, probability]`` line per class, ordered from most
        to least probable, or an empty string when there is no image.
    """
    # Defensive guard: without it a missing image raised a TypeError or
    # AttributeError instead of yielding an empty result.
    composite = None if img is None else img['composite']
    if composite is None:
        return ""

    pixels = torch.from_numpy(composite.astype(np.float32))
    with torch.no_grad():
        # Add a leading channel axis so Resize treats the last two dims as H, W.
        pixels = resize(pixels.unsqueeze(0))
        probabilities = lenet.predict(pixels)

    # Stable sort, highest probability first (ties keep class-index order).
    ranked = sorted(enumerate(probabilities.numpy()), key=lambda pair: -pair[1])
    # Format with str() rather than the list repr: NumPy 2 renders scalars
    # inside containers as "np.float32(...)", which would leak into the UI.
    return "\n".join(f"[{index}, {prob!s}]" for index, prob in ranked)
# UI wiring: a sketchpad feeds on_submit and the returned ranking is shown as
# plain text.
sketch_input = gr.Sketchpad(image_mode='P')
ranking_output = gr.Text()
iface = gr.Interface(
    fn=on_submit,
    inputs=sketch_input,
    outputs=ranking_output,
    title="LeNet",
)
# Start serving the interface.
iface.launch()