# Captured from a Hugging Face Spaces page whose status read: "Runtime error".
import torch | |
import gradio as gr | |
from torchvision import transforms | |
from PIL import ImageOps | |
def load_model(path='gradio-app/linear_model.pt'):
    """Load the saved linear classifier parameters from disk.

    Args:
        path: Location of the checkpoint written with ``torch.save``. The
            default is the same relative path used previously, so it is
            still resolved against the current working directory.

    Returns:
        The object stored in the checkpoint. ``predict`` in this app indexes
        it with ``'weights'`` and ``'bias'``.
    """
    # map_location='cpu': without it, torch.load tries to restore tensors
    # onto the device they were saved from. A checkpoint saved from a GPU
    # session then fails to load on a CPU-only host such as a basic Space.
    return torch.load(path, map_location='cpu')
# Module-level resources shared by predict():
#   convert_tensor -- torchvision transform that turns a PIL image into a tensor
#   model          -- the checkpoint returned by load_model()
convert_tensor = transforms.ToTensor()
model = load_model()
def predict(img):
    """Score a drawn digit as either a 3 or a 7.

    Args:
        img: PIL image from the ``gr.Paint(type="pil")`` input.

    Returns:
        A dict mapping "It's 3" and "It's 7" to floats. The first is the
        sigmoid of the linear score; the second is its complement.
    """
    # Collapse to one grayscale channel and shrink to a 28x28 grid.
    small = ImageOps.grayscale(img).resize((28, 28))
    # Flatten to a single 784-element vector for the linear layer.
    pixels = convert_tensor(small).view(28 * 28)
    # Assumes model['weights'] is (784, 1) or (784,) and model['bias'] has a
    # single element, so float() below receives a one-element tensor.
    # TODO confirm against the saved checkpoint.
    score = pixels @ model['weights'] + model['bias']
    prob_three = torch.sigmoid(score)
    prob_seven = 1 - prob_three
    return {"It's 3": float(prob_three), "It's 7": float(prob_seven)}
title = "Is it 7 or 3"
description = '<p><center>Write a number, 7 or 3, in the middle.</center></p>'

# Web UI: a drawing canvas feeds predict() and the two class scores are shown
# in a Label. Flagging is disabled.
# NOTE(review): the `gr.Paint(..., invert_colors=True)` and `allow_flagging`
# arguments appear to target the gradio 3.x API. If the Space's gradio
# version is upgraded, confirm these arguments are still accepted, since the
# constructor call fails at startup otherwise.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Paint(type="pil", invert_colors=True),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    allow_flagging='never',
)
demo.launch()