import gradio as gr
import torch
from model import *

from PIL import Image
import torchvision.transforms as transforms

# Header text handed to gr.Interface as its `title` / `description` arguments.
title = "Digit Classifier"
description = "".join(
    (
        "Multilayer-Perceptron built for the fast.ai 'Deep Learning' course ",
        "to classify handwritten digits from the MNIST dataset. ",
    )
)
# Gradio I/O components for the Interface: an image input feeding
# predict_digit and a Label output showing the per-digit probabilities.
inputs, outputs = gr.components.Image(), gr.components.Label()
# Passed as gr.Interface(examples=...); Gradio treats a plain string as the
# path of a directory containing example inputs.
examples = "examples"

# Load the trained classifier onto the CPU.
#
# The checkpoint holds a whole pickled nn.Module, because the loaded object
# is called directly as `model(img)`. Unpickling therefore needs the model
# class importable (presumably the reason for `from model import *`). It also
# needs `weights_only=False`: since PyTorch 2.6, `torch.load` defaults to
# `weights_only=True`, which refuses to unpickle arbitrary classes.
# NOTE(security): full unpickling can execute code embedded in the file;
# only load checkpoints from a trusted source.
model = torch.load(
    "model/digit_classifier.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
)
# Put the network in inference mode. This matters if it contains dropout or
# batch-norm layers (not visible from here); it is harmless otherwise.
model.eval()

# Class names for the model outputs; output position i maps to digit str(i).
labels = [str(i) for i in range(10)]

# Preprocessing applied to each PIL image before it is fed to the model:
#   1. resize to exactly 28x28 pixels (the MNIST resolution);
#   2. convert to a single grayscale ('L') channel;
#   3. convert to a float tensor scaled to [0, 1].
# ToTensor on a single-channel image already yields shape (1, 28, 28). The
# leading size-1 axis presumably serves as the batch dimension that
# predict_digit's softmax(dim=1) relies on (assumed; confirm against the
# model's forward()). No extra reshaping step is needed.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)


def predict_digit(img):
    """Classify one handwritten-digit image.

    Args:
        img: Image array supplied by the ``gr.Image`` input (converted with
            ``PIL.Image.fromarray``), or ``None`` when Gradio submits
            without an image.

    Returns:
        Dict mapping each label ``"0"``..``"9"`` to its softmax probability
        as a Python float, for display by the ``gr.Label`` output. Returns
        an empty dict when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to classify; avoid crashing inside Image.fromarray(None).
        return {}

    batch = transform(Image.fromarray(img))
    # Inference only: disable autograd tracking for the forward pass.
    with torch.no_grad():
        logits = model(batch)
        # Assumes logits are (1, num_classes); dim=1 is the class axis.
        probs = torch.nn.functional.softmax(logits, dim=1)
    return dict(zip(labels, map(float, probs.flatten()[:10])))


# Assemble the app: a single "Digit Prediction" tab hosting an Interface
# that wires predict_digit to the image input and label output.
with gr.Blocks() as demo:
    with gr.Tab("Digit Prediction"):
        gr.Interface(
            fn=predict_digit,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,
            description=description,
        )

# Configure the queue on the top-level Blocks, which is the app actually
# launched. Queue settings on the nested Interface are not guaranteed to
# carry over to `demo`. default_concurrency_limit=5 allows up to five events
# that don't set their own limit to run at the same time.
demo.queue(default_concurrency_limit=5).launch()